diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index e1ffccc001b0..b92389a93b86 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -63,37 +63,37 @@ jobs: echo '::set-output name=branch::master' echo '::set-output name=type::scheduled' echo '::set-output name=envs::{"SCALA_PROFILE": "scala2.13"}' - echo '::set-output name=hadoop::hadoop3.3' + echo '::set-output name=hadoop::hadoop3.2' elif [ "${{ github.event.schedule }}" = "0 7 * * *" ]; then echo '::set-output name=java::8' echo '::set-output name=branch::branch-3.2' echo '::set-output name=type::scheduled' echo '::set-output name=envs::{"SCALA_PROFILE": "scala2.13"}' - echo '::set-output name=hadoop::hadoop3.3' + echo '::set-output name=hadoop::hadoop3.2' elif [ "${{ github.event.schedule }}" = "0 10 * * *" ]; then echo '::set-output name=java::8' echo '::set-output name=branch::master' echo '::set-output name=type::pyspark-coverage-scheduled' echo '::set-output name=envs::{"PYSPARK_CODECOV": "true"}' - echo '::set-output name=hadoop::hadoop3.3' + echo '::set-output name=hadoop::hadoop3.2' elif [ "${{ github.event.schedule }}" = "0 13 * * *" ]; then echo '::set-output name=java::11' echo '::set-output name=branch::master' echo '::set-output name=type::scheduled' echo '::set-output name=envs::{"SKIP_MIMA": "true", "SKIP_UNIDOC": "true"}' - echo '::set-output name=hadoop::hadoop3.3' + echo '::set-output name=hadoop::hadoop3.2' elif [ "${{ github.event.schedule }}" = "0 16 * * *" ]; then echo '::set-output name=java::17' echo '::set-output name=branch::master' echo '::set-output name=type::scheduled' echo '::set-output name=envs::{"SKIP_MIMA": "true", "SKIP_UNIDOC": "true"}' - echo '::set-output name=hadoop::hadoop3.3' + echo '::set-output name=hadoop::hadoop3.2' else echo '::set-output name=java::8' echo '::set-output name=branch::master' # Default branch to run on. CHANGE here when a branch is cut out. echo '::set-output name=type::regular' echo '::set-output name=envs::{}' - echo '::set-output name=hadoop::hadoop3.3' + echo '::set-output name=hadoop::hadoop3.2' fi # Build: build Spark and run the tests for specified modules. 
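For context, the `hadoop` output selected above ends up as the `HADOOP_PROFILE` value consumed by the build scripts, which translate it into Maven/SBT profile flags. A condensed sketch of that resolution, mirroring `dev/run-tests.py` as changed later in this diff (not a verbatim copy):

```python
# Sketch: how the workflow's "hadoop3.2" output becomes Maven/SBT profile flags.
import os
import sys

# After the revert, "hadoop3.2" maps to -Phadoop-3.2; the old "hadoop3.3"/-Phadoop-3
# entry no longer exists.
SBT_MAVEN_HADOOP_PROFILES = {
    "hadoop2.7": ["-Phadoop-2.7"],
    "hadoop3.2": ["-Phadoop-3.2"],
}

def get_hadoop_profiles(hadoop_version):
    try:
        return SBT_MAVEN_HADOOP_PROFILES[hadoop_version]
    except KeyError:
        print("[error] Could not find", hadoop_version, "in the list. Valid options are",
              SBT_MAVEN_HADOOP_PROFILES.keys())
        sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))

# The default also changes from "hadoop3.3" to "hadoop3.2".
print(get_hadoop_profiles(os.environ.get("HADOOP_PROFILE", "hadoop3.2")))  # ['-Phadoop-3.2']
```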
diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 1ebd8bd89fd4..339870195044 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -463,14 +463,14 @@ object ResourceProfile extends Logging { case ResourceProfile.CORES => cores = execReq.amount.toInt case rName => - val nameToUse = resourceMappings.get(rName).getOrElse(rName) + val nameToUse = resourceMappings.getOrElse(rName, rName) customResources(nameToUse) = execReq } } customResources.toMap } else { defaultResources.customResources.map { case (rName, execReq) => - val nameToUse = resourceMappings.get(rName).getOrElse(rName) + val nameToUse = resourceMappings.getOrElse(rName, rName) (nameToUse, execReq) } } diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala index d538f0bcc423..2858443c7cd3 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala @@ -57,8 +57,10 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf, private val notRunningUnitTests = !isTesting private val testExceptionThrown = sparkConf.get(RESOURCE_PROFILE_MANAGER_TESTING) - // If we use anything except the default profile, its only supported on YARN right now. - // Throw an exception if not supported. + /** + * If we use anything except the default profile, it's only supported on YARN and Kubernetes + * with dynamic allocation enabled. Throw an exception if not supported. + */ private[spark] def isSupported(rp: ResourceProfile): Boolean = { val isNotDefaultProfile = rp.id != ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID val notYarnOrK8sAndNotDefaultProfile = isNotDefaultProfile && !(isYarn || isK8s) @@ -103,7 +105,7 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf, def resourceProfileFromId(rpId: Int): ResourceProfile = { readLock.lock() try { - resourceProfileIdToResourceProfile.get(rpId).getOrElse( + resourceProfileIdToResourceProfile.getOrElse(rpId, throw new SparkException(s"ResourceProfileId $rpId not found!") ) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 208c676a1c35..626a237732e3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -56,7 +56,7 @@ private[spark] class FetchFailedException( // which intercepts this exception (possibly wrapping it), the Executor can still tell there was // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option // because the TaskContext is not defined in some test cases. 
- Option(TaskContext.get()).map(_.setFetchFailed(this)) + Option(TaskContext.get()).foreach(_.setFetchFailed(this)) def toTaskFailedReason: TaskFailedReason = FetchFailed( bmAddress, shuffleId, mapId, mapIndex, reduceId, Utils.exceptionString(this)) diff --git a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala index 76137133227f..d137099e7343 100644 --- a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala +++ b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala @@ -31,6 +31,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{STORAGE_DECOMMISSION_FALLBACK_STORAGE_CLEANUP, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID @@ -60,15 +61,17 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging { val indexFile = r.getIndexFile(shuffleId, mapId) if (indexFile.exists()) { + val hash = JavaUtils.nonNegativeHash(indexFile.getName) fallbackFileSystem.copyFromLocalFile( new Path(indexFile.getAbsolutePath), - new Path(fallbackPath, s"$appId/$shuffleId/${indexFile.getName}")) + new Path(fallbackPath, s"$appId/$shuffleId/$hash/${indexFile.getName}")) val dataFile = r.getDataFile(shuffleId, mapId) if (dataFile.exists()) { + val hash = JavaUtils.nonNegativeHash(dataFile.getName) fallbackFileSystem.copyFromLocalFile( new Path(dataFile.getAbsolutePath), - new Path(fallbackPath, s"$appId/$shuffleId/${dataFile.getName}")) + new Path(fallbackPath, s"$appId/$shuffleId/$hash/${dataFile.getName}")) } // Report block statuses @@ -86,7 +89,8 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging { } def exists(shuffleId: Int, filename: String): Boolean = { - fallbackFileSystem.exists(new Path(fallbackPath, s"$appId/$shuffleId/$filename")) + val hash = JavaUtils.nonNegativeHash(filename) + fallbackFileSystem.exists(new Path(fallbackPath, s"$appId/$shuffleId/$hash/$filename")) } } @@ -168,7 +172,8 @@ private[spark] object FallbackStorage extends Logging { } val name = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID).name - val indexFile = new Path(fallbackPath, s"$appId/$shuffleId/$name") + val hash = JavaUtils.nonNegativeHash(name) + val indexFile = new Path(fallbackPath, s"$appId/$shuffleId/$hash/$name") val start = startReduceId * 8L val end = endReduceId * 8L Utils.tryWithResource(fallbackFileSystem.open(indexFile)) { inputStream => @@ -178,7 +183,8 @@ private[spark] object FallbackStorage extends Logging { index.skip(end - (start + 8L)) val nextOffset = index.readLong() val name = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID).name - val dataFile = new Path(fallbackPath, s"$appId/$shuffleId/$name") + val hash = JavaUtils.nonNegativeHash(name) + val dataFile = new Path(fallbackPath, s"$appId/$shuffleId/$hash/$name") val f = fallbackFileSystem.open(dataFile) val size = nextOffset - offset logDebug(s"To byte array $size") diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 88f8ac7f1fa6..44baeddb6f6c 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -192,7 
+192,7 @@ SCALA_2_12_PROFILES="-Pscala-2.12" HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central # We use Apache Hive 2.3 for publishing -PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Phive-2.3 -Pspark-ganglia-lgpl -Pkinesis-asl -Phadoop-cloud" +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl -Phadoop-cloud" # Profiles for building binary releases BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" @@ -322,18 +322,18 @@ if [[ "$1" == "package" ]]; then # 'python/pyspark/install.py' and 'python/docs/source/getting_started/install.rst' # if you're changing them. declare -A BINARY_PKGS_ARGS - BINARY_PKGS_ARGS["hadoop3.3"]="-Phadoop-3 $HIVE_PROFILES" + BINARY_PKGS_ARGS["hadoop3.2"]="-Phadoop-3.2 $HIVE_PROFILES" if ! is_dry_run; then BINARY_PKGS_ARGS["without-hadoop"]="-Phadoop-provided" BINARY_PKGS_ARGS["hadoop2.7"]="-Phadoop-2.7 $HIVE_PROFILES" fi declare -A BINARY_PKGS_EXTRA - BINARY_PKGS_EXTRA["hadoop3.3"]="withpip,withr" + BINARY_PKGS_EXTRA["hadoop3.2"]="withpip,withr" if [[ $PUBLISH_SCALA_2_13 = 1 ]]; then - key="hadoop3.3-scala2.13" - args="-Phadoop-3 $HIVE_PROFILES" + key="hadoop3.2-scala2.13" + args="-Phadoop-3.2 $HIVE_PROFILES" extra="" if ! make_binary_release "$key" "$SCALA_2_13_PROFILES $args" "$extra" "2.13"; then error "Failed to build $key package. Check logs for details." diff --git a/dev/deps/spark-deps-hadoop-3.3-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 similarity index 100% rename from dev/deps/spark-deps-hadoop-3.3-hive-2.3 rename to dev/deps/spark-deps-hadoop-3.2-hive-2.3 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index ca616d7459c7..67d0972acc68 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -172,11 +172,8 @@ def main(): # Switch the Hadoop profile based on the PR title: if "test-hadoop2.7" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.7" - if "test-hadoop3.3" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop3.3" - # Switch the Hive profile based on the PR title: - if "test-hive2.3" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_HIVE_PROFILE"] = "hive2.3" + if "test-hadoop3.2" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop3.2" # Switch the Scala profile based on the PR title: if "test-scala2.13" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_SCALA_PROFILE"] = "scala2.13" diff --git a/dev/run-tests.py b/dev/run-tests.py index 3778123ae638..25df8f62aca4 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -334,7 +334,7 @@ def get_hadoop_profiles(hadoop_version): sbt_maven_hadoop_profiles = { "hadoop2.7": ["-Phadoop-2.7"], - "hadoop3.3": ["-Phadoop-3"], + "hadoop3.2": ["-Phadoop-3.2"], } if hadoop_version in sbt_maven_hadoop_profiles: @@ -345,24 +345,6 @@ def get_hadoop_profiles(hadoop_version): sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) -def get_hive_profiles(hive_version): - """ - For the given Hive version tag, return a list of Maven/SBT profile flags for - building and testing against that Hive version. - """ - - sbt_maven_hive_profiles = { - "hive2.3": ["-Phive-2.3"], - } - - if hive_version in sbt_maven_hive_profiles: - return sbt_maven_hive_profiles[hive_version] - else: - print("[error] Could not find", hive_version, "in the list. 
Valid options", - " are", sbt_maven_hive_profiles.keys()) - sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) - - def build_spark_maven(extra_profiles): # Enable all of the profiles for the build: build_profiles = extra_profiles + modules.root.build_profile_flags @@ -615,8 +597,7 @@ def main(): # to reflect the environment settings build_tool = os.environ.get("AMPLAB_JENKINS_BUILD_TOOL", "sbt") scala_version = os.environ.get("AMPLAB_JENKINS_BUILD_SCALA_PROFILE") - hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop3.3") - hive_version = os.environ.get("AMPLAB_JENKINS_BUILD_HIVE_PROFILE", "hive2.3") + hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop3.2") test_env = "amplab_jenkins" # add path for Python3 in Jenkins if we're calling from a Jenkins machine # TODO(sknapp): after all builds are ported to the ubuntu workers, change this to be: @@ -626,15 +607,13 @@ def main(): # else we're running locally or GitHub Actions. build_tool = "sbt" scala_version = os.environ.get("SCALA_PROFILE") - hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop3.3") - hive_version = os.environ.get("HIVE_PROFILE", "hive2.3") + hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop3.2") if "GITHUB_ACTIONS" in os.environ: test_env = "github_actions" else: test_env = "local" - extra_profiles = get_hadoop_profiles(hadoop_version) + get_hive_profiles(hive_version) + \ - get_scala_profiles(scala_version) + extra_profiles = get_hadoop_profiles(hadoop_version) + get_scala_profiles(scala_version) print("[info] Using build tool", build_tool, "with profiles", *(extra_profiles + ["under environment", test_env])) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d13be2e2fe29..5dd3ab616950 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -464,6 +464,7 @@ def __hash__(self): "pyspark.sql.tests.test_streaming", "pyspark.sql.tests.test_types", "pyspark.sql.tests.test_udf", + "pyspark.sql.tests.test_udf_profiler", "pyspark.sql.tests.test_utils", ] ) @@ -606,6 +607,7 @@ def __hash__(self): "pyspark.pandas.namespace", "pyspark.pandas.numpy_compat", "pyspark.pandas.sql_processor", + "pyspark.pandas.sql_formatter", "pyspark.pandas.strings", "pyspark.pandas.utils", "pyspark.pandas.window", diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 5ecb305c0fe4..e23a0b682bcc 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -35,7 +35,7 @@ HADOOP_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkubernetes -Pyarn -Phive \ MVN="build/mvn" HADOOP_HIVE_PROFILES=( hadoop-2.7-hive-2.3 - hadoop-3.3-hive-2.3 + hadoop-3.2-hive-2.3 ) # We'll switch the version to a temp. 
one, publish POMs using that new version, then switch back to @@ -84,22 +84,20 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de # Generate manifests for each Hadoop profile: for HADOOP_HIVE_PROFILE in "${HADOOP_HIVE_PROFILES[@]}"; do - if [[ $HADOOP_HIVE_PROFILE == **hadoop-3.3-hive-2.3** ]]; then - HADOOP_PROFILE=hadoop-3 - HIVE_PROFILE=hive-2.3 + if [[ $HADOOP_HIVE_PROFILE == **hadoop-3.2-hive-2.3** ]]; then + HADOOP_PROFILE=hadoop-3.2 else HADOOP_PROFILE=hadoop-2.7 - HIVE_PROFILE=hive-2.3 fi echo "Performing Maven install for $HADOOP_HIVE_PROFILE" - $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE jar:jar jar:test-jar install:install clean -q + $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install clean -q echo "Performing Maven validate for $HADOOP_HIVE_PROFILE" - $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE validate -q + $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE validate -q echo "Generating dependency manifest for $HADOOP_HIVE_PROFILE" mkdir -p dev/pr-deps - $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE dependency:build-classpath -pl assembly -am \ + $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly -am \ | grep "Dependencies classpath:" -A 1 \ | tail -n 1 | tr ":" "\n" | awk -F '/' '{ # For each dependency classpath, we fetch the last three parts split by "/": artifact id, version, and jar name. diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 9ad7ad62117c..3592f6be16a4 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -528,7 +528,7 @@ Below is a list of all the keywords in Spark SQL. |ROW|non-reserved|non-reserved|reserved| |ROWS|non-reserved|non-reserved|reserved| |SCHEMA|non-reserved|non-reserved|non-reserved| -|SCHEMAS|non-reserved|non-reserved|not a keyword| +|SCHEMAS|non-reserved|non-reserved|non-reserved| |SECOND|non-reserved|non-reserved|non-reserved| |SELECT|reserved|non-reserved|reserved| |SEMI|non-reserved|strict-non-reserved|non-reserved| diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala index 438f63c75b87..c480fba121fd 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala @@ -387,11 +387,11 @@ private[kafka010] class KafkaOffsetReaderAdmin( // Calculate offset ranges val offsetRangesBase = untilPartitionOffsets.keySet.map { tp => - val fromOffset = fromPartitionOffsets.get(tp).getOrElse { + val fromOffset = fromPartitionOffsets.getOrElse(tp, // This should not happen since topicPartitions contains all partitions not in // fromPartitionOffsets throw new IllegalStateException(s"$tp doesn't have a from offset") - } + ) val untilOffset = untilPartitionOffsets(tp) KafkaOffsetRange(tp, fromOffset, untilOffset, None) }.toSeq diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 02021877c525..2e95b5778a80 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -190,7 +190,7 @@ @@ -201,7 +201,7 @@ enables store-specific committers. 
--> - hadoop-3 + hadoop-3.2 true diff --git a/pom.xml b/pom.xml index c93dbf7f8a5a..87e3489f101a 100644 --- a/pom.xml +++ b/pom.xml @@ -3349,15 +3349,10 @@ - hadoop-3 + hadoop-3.2 - - hive-2.3 - - - yarn diff --git a/python/docs/source/development/debugging.rst b/python/docs/source/development/debugging.rst index 829919858f67..1e6571da0289 100644 --- a/python/docs/source/development/debugging.rst +++ b/python/docs/source/development/debugging.rst @@ -277,4 +277,58 @@ executor side, which can be enabled by setting ``spark.python.profile`` configur 12 0.000 0.000 0.001 0.000 context.py:506(f) ... -This feature is supported only with RDD APIs. +Python/Pandas UDF +~~~~~~~~~~~~~~~~~ + +To use this on Python/Pandas UDFs, PySpark provides remote `Python Profilers `_ for +Python/Pandas UDFs, which can be enabled by setting ``spark.python.profile`` configuration to ``true``. + +.. code-block:: bash + + pyspark --conf spark.python.profile=true + + +.. code-block:: python + + >>> from pyspark.sql.functions import pandas_udf + >>> df = spark.range(10) + >>> @pandas_udf("long") + ... def add1(x): + ... return x + 1 + ... + >>> added = df.select(add1("id")) + + >>> added.show() + +--------+ + |add1(id)| + +--------+ + ... + +--------+ + + >>> sc.show_profiles() + ============================================================ + Profile of UDF + ============================================================ + 2300 function calls (2270 primitive calls) in 0.006 seconds + + Ordered by: internal time, cumulative time + + ncalls tottime percall cumtime percall filename:lineno(function) + 10 0.001 0.000 0.005 0.001 series.py:5515(_arith_method) + 10 0.001 0.000 0.001 0.000 _ufunc_config.py:425(__init__) + 10 0.000 0.000 0.000 0.000 {built-in method _operator.add} + 10 0.000 0.000 0.002 0.000 series.py:315(__init__) + ... + +The UDF IDs can be seen in the query plan, for example, ``add1(...)#2L`` in ``ArrowEvalPython`` below. + +.. code-block:: python + + >>> added.explain() + == Physical Plan == + *(2) Project [pythonUDF0#11L AS add1(id)#3L] + +- ArrowEvalPython [add1(id#0L)#2L], [pythonUDF0#11L], 200 + +- *(1) Range (0, 10, step=1, splits=16) + + +This feature is not supported with registered UDFs. 
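As a usage note for the new UDF profiling docs above: besides printing with `sc.show_profiles()`, the collected stats can be written to disk. This is a small sketch assuming the pre-existing `SparkContext.dump_profiles` API and the `spark.python.profile.dump` configuration (the dump-path handling is visible in `python/pyspark/context.py` later in this diff); the output path is illustrative only:

```python
# Assumes a pyspark shell started with: pyspark --conf spark.python.profile=true
from pyspark.sql.functions import pandas_udf

df = spark.range(10)

@pandas_udf("long")
def add1(x):
    return x + 1

# Run the UDF so the remote profiler accumulates stats on the executors.
df.select(add1("id")).collect()

# Either print the accumulated profiles ...
sc.show_profiles()
# ... or dump one stats file per profiled UDF/RDD id for later offline inspection.
sc.dump_profiles("/tmp/pyspark_profiles")  # path chosen for illustration
```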
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 13c6f8f3a28e..601b45d00a7c 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -154,11 +154,11 @@ Dependencies ============= ========================= ====================================== Package Minimum supported version Note ============= ========================= ====================================== -`pandas` 0.23.2 Optional for Spark SQL +`pandas` 1.0.5 Optional for Spark SQL `NumPy` 1.7 Required for MLlib DataFrame-based API `pyarrow` 1.0.0 Optional for Spark SQL `Py4J` 0.10.9.2 Required -`pandas` 0.23.2 Required for pandas API on Spark +`pandas` 1.0.5 Required for pandas API on Spark `pyarrow` 1.0.0 Required for pandas API on Spark `Numpy` 1.14 Required for pandas API on Spark ============= ========================= ====================================== diff --git a/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst b/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst index 060f24c8f41f..f2701d4fb721 100644 --- a/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst +++ b/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst @@ -20,4 +20,6 @@ Upgrading from PySpark 3.2 to 3.3 ================================= +* In Spark 3.3, the ``pyspark.pandas.sql`` method follows [the standard Python string formatter](https://docs.python.org/3/library/string.html#format-string-syntax). To restore the previous behavior, set ``PYSPARK_PANDAS_SQL_LEGACY`` environment variable to ``1``. * In Spark 3.3, the ``drop`` method of pandas API on Spark DataFrame supports dropping rows by ``index``, and sets dropping by index instead of column by default. +* In Spark 3.3, PySpark upgrades Pandas version, the new minimum required version changes from 0.23.2 to 1.0.5. diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index bb84202f165f..04bfe27c247f 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -148,6 +148,7 @@ Computations / Descriptive Stats DataFrame.clip DataFrame.corr DataFrame.count + DataFrame.cov DataFrame.describe DataFrame.kurt DataFrame.kurtosis diff --git a/python/docs/source/user_guide/pandas_on_spark/options.rst b/python/docs/source/user_guide/pandas_on_spark/options.rst index 8f18f8ef8eb7..06a27ecbe8aa 100644 --- a/python/docs/source/user_guide/pandas_on_spark/options.rst +++ b/python/docs/source/user_guide/pandas_on_spark/options.rst @@ -286,7 +286,10 @@ compute.eager_check True 'compute.eager_check' sets whethe performs the validation beforehand, but it will cause a performance overhead. Otherwise, pandas-on-Spark skip the validation and will be slightly different - from pandas. Affected APIs: `Series.dot`. + from pandas. Affected APIs: `Series.dot`, + `Series.asof`, `FractionalExtensionOps.astype`, + `IntegralExtensionOps.astype`, `FractionalOps.astype`, + `DecimalOps.astype`. compute.isin_limit 80 'compute.isin_limit' sets the limit for filtering by 'Column.isin(list)'. 
If the length of the ‘list’ is above the limit, broadcast join is used instead for diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index 78d3e7ad84e3..20a9f935d586 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -387,7 +387,7 @@ working with timestamps in ``pandas_udf``\s to get the best performance, see Recommended Pandas and PyArrow Versions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For usage with pyspark.sql, the minimum supported versions of Pandas is 0.23.2 and PyArrow is 1.0.0. +For usage with pyspark.sql, the minimum supported versions of Pandas is 1.0.5 and PyArrow is 1.0.0. Higher versions may be used, however, compatibility and data correctness can not be guaranteed and should be verified by the user. diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 70392fb1df48..aab95aded046 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -57,7 +57,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast -from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.serializers import MarshalSerializer, CPickleSerializer from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ @@ -136,7 +136,7 @@ def wrapper(self, *args, **kwargs): "Accumulator", "AccumulatorParam", "MarshalSerializer", - "PickleSerializer", + "CPickleSerializer", "StatusTracker", "SparkJobInfo", "SparkStageInfo", diff --git a/python/pyspark/__init__.pyi b/python/pyspark/__init__.pyi index 35df545ee64b..fb045f2e5c54 100644 --- a/python/pyspark/__init__.pyi +++ b/python/pyspark/__init__.pyi @@ -38,7 +38,7 @@ from pyspark.profiler import ( # noqa: F401 from pyspark.rdd import RDD as RDD, RDDBarrier as RDDBarrier # noqa: F401 from pyspark.serializers import ( # noqa: F401 MarshalSerializer as MarshalSerializer, - PickleSerializer as PickleSerializer, + CPickleSerializer as CPickleSerializer, ) from pyspark.status import ( # noqa: F401 SparkJobInfo as SparkJobInfo, diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index c43ebe417b54..2ea2a4952e0a 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -20,13 +20,13 @@ import struct import socketserver as SocketServer import threading -from pyspark.serializers import read_int, PickleSerializer +from pyspark.serializers import read_int, CPickleSerializer __all__ = ["Accumulator", "AccumulatorParam"] -pickleSer = PickleSerializer() +pickleSer = CPickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. diff --git a/python/pyspark/cloudpickle/__init__.py b/python/pyspark/cloudpickle/__init__.py index 56506d95fa1b..0ae79b5535c8 100644 --- a/python/pyspark/cloudpickle/__init__.py +++ b/python/pyspark/cloudpickle/__init__.py @@ -8,4 +8,4 @@ # expose their Pickler subclass at top-level under the "Pickler" name. 
Pickler = CloudPickler -__version__ = '1.6.0' +__version__ = '2.0.0' diff --git a/python/pyspark/cloudpickle/cloudpickle.py b/python/pyspark/cloudpickle/cloudpickle.py index 05d52afa0da9..347b38695803 100644 --- a/python/pyspark/cloudpickle/cloudpickle.py +++ b/python/pyspark/cloudpickle/cloudpickle.py @@ -55,6 +55,7 @@ import warnings from .compat import pickle +from collections import OrderedDict from typing import Generic, Union, Tuple, Callable from pickle import _getattribute from importlib._bootstrap import _find_spec @@ -87,8 +88,11 @@ def g(): # communication speed over compatibility: DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL +# Names of modules whose resources should be treated as dynamic. +_PICKLE_BY_VALUE_MODULES = set() + # Track the provenance of reconstructed dynamic classes to make it possible to -# recontruct instances from the matching singleton class definition when +# reconstruct instances from the matching singleton class definition when # appropriate and preserve the usual "isinstance" semantics of Python objects. _DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() _DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() @@ -123,6 +127,77 @@ def _lookup_class_or_track(class_tracker_id, class_def): return class_def +def register_pickle_by_value(module): + """Register a module to make it functions and classes picklable by value. + + By default, functions and classes that are attributes of an importable + module are to be pickled by reference, that is relying on re-importing + the attribute from the module at load time. + + If `register_pickle_by_value(module)` is called, all its functions and + classes are subsequently to be pickled by value, meaning that they can + be loaded in Python processes where the module is not importable. + + This is especially useful when developing a module in a distributed + execution environment: restarting the client Python process with the new + source code is enough: there is no need to re-install the new version + of the module on all the worker nodes nor to restart the workers. + + Note: this feature is considered experimental. See the cloudpickle + README.md file for more details and limitations. + """ + if not isinstance(module, types.ModuleType): + raise ValueError( + f"Input should be a module object, got {str(module)} instead" + ) + # In the future, cloudpickle may need a way to access any module registered + # for pickling by value in order to introspect relative imports inside + # functions pickled by value. (see + # https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633). + # This access can be ensured by checking that module is present in + # sys.modules at registering time and assuming that it will still be in + # there when accessed during pickling. Another alternative would be to + # store a weakref to the module. Even though cloudpickle does not implement + # this introspection yet, in order to avoid a possible breaking change + # later, we still enforce the presence of module inside sys.modules. + if module.__name__ not in sys.modules: + raise ValueError( + f"{module} was not imported correctly, have you used an " + f"`import` statement to access it?" 
+ ) + _PICKLE_BY_VALUE_MODULES.add(module.__name__) + + +def unregister_pickle_by_value(module): + """Unregister that the input module should be pickled by value.""" + if not isinstance(module, types.ModuleType): + raise ValueError( + f"Input should be a module object, got {str(module)} instead" + ) + if module.__name__ not in _PICKLE_BY_VALUE_MODULES: + raise ValueError(f"{module} is not registered for pickle by value") + else: + _PICKLE_BY_VALUE_MODULES.remove(module.__name__) + + +def list_registry_pickle_by_value(): + return _PICKLE_BY_VALUE_MODULES.copy() + + +def _is_registered_pickle_by_value(module): + module_name = module.__name__ + if module_name in _PICKLE_BY_VALUE_MODULES: + return True + while True: + parent_name = module_name.rsplit(".", 1)[0] + if parent_name == module_name: + break + if parent_name in _PICKLE_BY_VALUE_MODULES: + return True + module_name = parent_name + return False + + def _whichmodule(obj, name): """Find the module an object belongs to. @@ -136,11 +211,14 @@ def _whichmodule(obj, name): # Workaround bug in old Python versions: prior to Python 3.7, # T.__module__ would always be set to "typing" even when the TypeVar T # would be defined in a different module. - # - # For such older Python versions, we ignore the __module__ attribute of - # TypeVar instances and instead exhaustively lookup those instances in - # all currently imported modules. - module_name = None + if name is not None and getattr(typing, name, None) is obj: + # Built-in TypeVar defined in typing such as AnyStr + return 'typing' + else: + # User defined or third-party TypeVar: __module__ attribute is + # irrelevant, thus trigger a exhaustive search for obj in all + # modules. + module_name = None else: module_name = getattr(obj, '__module__', None) @@ -166,18 +244,35 @@ def _whichmodule(obj, name): return None -def _is_importable(obj, name=None): - """Dispatcher utility to test the importability of various constructs.""" - if isinstance(obj, types.FunctionType): - return _lookup_module_and_qualname(obj, name=name) is not None - elif issubclass(type(obj), type): - return _lookup_module_and_qualname(obj, name=name) is not None +def _should_pickle_by_reference(obj, name=None): + """Test whether an function or a class should be pickled by reference + + Pickling by reference means by that the object (typically a function or a + class) is an attribute of a module that is assumed to be importable in the + target Python environment. Loading will therefore rely on importing the + module and then calling `getattr` on it to access the function or class. + + Pickling by reference is the only option to pickle functions and classes + in the standard library. In cloudpickle the alternative option is to + pickle by value (for instance for interactively or locally defined + functions and classes or for attributes of modules that have been + explicitly registered to be pickled by value. + """ + if isinstance(obj, types.FunctionType) or issubclass(type(obj), type): + module_and_name = _lookup_module_and_qualname(obj, name=name) + if module_and_name is None: + return False + module, name = module_and_name + return not _is_registered_pickle_by_value(module) + elif isinstance(obj, types.ModuleType): # We assume that sys.modules is primarily used as a cache mechanism for # the Python import machinery. Checking if a module has been added in - # is sys.modules therefore a cheap and simple heuristic to tell us whether - # we can assume that a given module could be imported by name in - # another Python process. 
+ # is sys.modules therefore a cheap and simple heuristic to tell us + # whether we can assume that a given module could be imported by name + # in another Python process. + if _is_registered_pickle_by_value(obj): + return False return obj.__name__ in sys.modules else: raise TypeError( @@ -233,10 +328,13 @@ def _extract_code_globals(co): out_names = _extract_code_globals_cache.get(co) if out_names is None: names = co.co_names - out_names = {names[oparg] for _, oparg in _walk_global_ops(co)} + # We use a dict with None values instead of a set to get a + # deterministic order (assuming Python 3.6+) and avoid introducing + # non-deterministic pickle bytes as a results. + out_names = {names[oparg]: None for _, oparg in _walk_global_ops(co)} # Declaring a function inside another one using the "def ..." - # syntax generates a constant code object corresonding to the one + # syntax generates a constant code object corresponding to the one # of the nested function's As the nested function may itself need # global variables, we need to introspect its code, extract its # globals, (look for code object in it's co_consts attribute..) and @@ -244,7 +342,7 @@ def _extract_code_globals(co): if co.co_consts: for const in co.co_consts: if isinstance(const, types.CodeType): - out_names |= _extract_code_globals(const) + out_names.update(_extract_code_globals(const)) _extract_code_globals_cache[co] = out_names @@ -452,15 +550,31 @@ def _extract_class_dict(cls): if sys.version_info[:2] < (3, 7): # pragma: no branch def _is_parametrized_type_hint(obj): - # This is very cheap but might generate false positives. + # This is very cheap but might generate false positives. So try to + # narrow it down is good as possible. + type_module = getattr(type(obj), '__module__', None) + from_typing_extensions = type_module == 'typing_extensions' + from_typing = type_module == 'typing' + # general typing Constructs is_typing = getattr(obj, '__origin__', None) is not None # typing_extensions.Literal - is_litteral = getattr(obj, '__values__', None) is not None + is_literal = ( + (getattr(obj, '__values__', None) is not None) + and from_typing_extensions + ) # typing_extensions.Final - is_final = getattr(obj, '__type__', None) is not None + is_final = ( + (getattr(obj, '__type__', None) is not None) + and from_typing_extensions + ) + + # typing.ClassVar + is_classvar = ( + (getattr(obj, '__type__', None) is not None) and from_typing + ) # typing.Union/Tuple for old Python 3.5 is_union = getattr(obj, '__union_params__', None) is not None @@ -469,8 +583,8 @@ def _is_parametrized_type_hint(obj): getattr(obj, '__result__', None) is not None and getattr(obj, '__args__', None) is not None ) - return any((is_typing, is_litteral, is_final, is_union, is_tuple, - is_callable)) + return any((is_typing, is_literal, is_final, is_classvar, is_union, + is_tuple, is_callable)) def _create_parametrized_type_hint(origin, args): return origin[args] @@ -557,8 +671,11 @@ def _rebuild_tornado_coroutine(func): loads = pickle.loads -# hack for __import__ not working as desired def subimport(name): + # We cannot do simply: `return __import__(name)`: Indeed, if ``name`` is + # the name of a submodule, __import__ will return the top-level root module + # of this submodule. For instance, __import__('os.path') returns the `os` + # module. 
__import__(name) return sys.modules[name] @@ -699,7 +816,7 @@ def _make_skel_func(code, cell_count, base_globals=None): """ # This function is deprecated and should be removed in cloudpickle 1.7 warnings.warn( - "A pickle file created using an old (<=1.4.1) version of cloudpicke " + "A pickle file created using an old (<=1.4.1) version of cloudpickle " "is currently being loaded. This is not supported by cloudpickle and " "will break in cloudpickle 1.7", category=UserWarning ) @@ -813,10 +930,15 @@ def _decompose_typevar(obj): def _typevar_reduce(obj): - # TypeVar instances have no __qualname__ hence we pass the name explicitly. + # TypeVar instances require the module information hence why we + # are not using the _should_pickle_by_reference directly module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__) + if module_and_name is None: return (_make_typevar, _decompose_typevar(obj)) + elif _is_registered_pickle_by_value(module_and_name[0]): + return (_make_typevar, _decompose_typevar(obj)) + return (getattr, module_and_name) @@ -830,13 +952,22 @@ def _get_bases(typ): return getattr(typ, bases_attr) -def _make_dict_keys(obj): - return dict.fromkeys(obj).keys() +def _make_dict_keys(obj, is_ordered=False): + if is_ordered: + return OrderedDict.fromkeys(obj).keys() + else: + return dict.fromkeys(obj).keys() -def _make_dict_values(obj): - return {i: _ for i, _ in enumerate(obj)}.values() +def _make_dict_values(obj, is_ordered=False): + if is_ordered: + return OrderedDict((i, _) for i, _ in enumerate(obj)).values() + else: + return {i: _ for i, _ in enumerate(obj)}.values() -def _make_dict_items(obj): - return obj.items() +def _make_dict_items(obj, is_ordered=False): + if is_ordered: + return OrderedDict(obj).items() + else: + return obj.items() diff --git a/python/pyspark/cloudpickle/cloudpickle_fast.py b/python/pyspark/cloudpickle/cloudpickle_fast.py index fa8da0f635c4..6db059eb858b 100644 --- a/python/pyspark/cloudpickle/cloudpickle_fast.py +++ b/python/pyspark/cloudpickle/cloudpickle_fast.py @@ -6,7 +6,7 @@ is only available for Python versions 3.8+, a lot of backward-compatibility code is also removed. -Note that the C Pickler sublassing API is CPython-specific. Therefore, some +Note that the C Pickler subclassing API is CPython-specific. Therefore, some guards present in cloudpickle.py that were written to handle PyPy specificities are not present in cloudpickle_fast.py """ @@ -23,12 +23,12 @@ import typing from enum import Enum -from collections import ChainMap +from collections import ChainMap, OrderedDict from .compat import pickle, Pickler from .cloudpickle import ( _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL, - _find_imported_submodules, _get_cell_contents, _is_importable, + _find_imported_submodules, _get_cell_contents, _should_pickle_by_reference, _builtin_type, _get_or_create_tracker_id, _make_skeleton_class, _make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType, @@ -180,7 +180,7 @@ def _class_getstate(obj): clsdict.pop('__weakref__', None) if issubclass(type(obj), abc.ABCMeta): - # If obj is an instance of an ABCMeta subclass, dont pickle the + # If obj is an instance of an ABCMeta subclass, don't pickle the # cache/negative caches populated during isinstance/issubclass # checks, but pickle the list of registered subclasses of obj. 
clsdict.pop('_abc_cache', None) @@ -244,7 +244,19 @@ def _enum_getstate(obj): def _code_reduce(obj): """codeobject reducer""" - if hasattr(obj, "co_posonlyargcount"): # pragma: no branch + if hasattr(obj, "co_linetable"): # pragma: no branch + # Python 3.10 and later: obj.co_lnotab is deprecated and constructor + # expects obj.co_linetable instead. + args = ( + obj.co_argcount, obj.co_posonlyargcount, + obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, + obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_linetable, obj.co_freevars, + obj.co_cellvars + ) + elif hasattr(obj, "co_posonlyargcount"): + # Backward compat for 3.9 and older args = ( obj.co_argcount, obj.co_posonlyargcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, @@ -254,6 +266,7 @@ def _code_reduce(obj): obj.co_cellvars ) else: + # Backward compat for even older versions of Python args = ( obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts, @@ -339,11 +352,16 @@ def _memoryview_reduce(obj): def _module_reduce(obj): - if _is_importable(obj): + if _should_pickle_by_reference(obj): return subimport, (obj.__name__,) else: - obj.__dict__.pop('__builtins__', None) - return dynamic_subimport, (obj.__name__, vars(obj)) + # Some external libraries can populate the "__builtins__" entry of a + # module's `__dict__` with unpicklable objects (see #316). For that + # reason, we do not attempt to pickle the "__builtins__" entry, and + # restore a default value for it at unpickling time. + state = obj.__dict__.copy() + state.pop('__builtins__', None) + return dynamic_subimport, (obj.__name__, state) def _method_reduce(obj): @@ -396,7 +414,7 @@ def _class_reduce(obj): return type, (NotImplemented,) elif obj in _BUILTIN_TYPE_NAMES: return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],) - elif not _is_importable(obj): + elif not _should_pickle_by_reference(obj): return _dynamic_class_reduce(obj) return NotImplemented @@ -419,6 +437,24 @@ def _dict_items_reduce(obj): return _make_dict_items, (dict(obj), ) +def _odict_keys_reduce(obj): + # Safer not to ship the full dict as sending the rest might + # be unintended and could potentially cause leaking of + # sensitive information + return _make_dict_keys, (list(obj), True) + + +def _odict_values_reduce(obj): + # Safer not to ship the full dict as sending the rest might + # be unintended and could potentially cause leaking of + # sensitive information + return _make_dict_values, (list(obj), True) + + +def _odict_items_reduce(obj): + return _make_dict_items, (dict(obj), True) + + # COLLECTIONS OF OBJECTS STATE SETTERS # ------------------------------------ # state setters are called at unpickling time, once the object is created and @@ -426,7 +462,7 @@ def _dict_items_reduce(obj): def _function_setstate(obj, state): - """Update the state of a dynaamic function. + """Update the state of a dynamic function. 
As __closure__ and __globals__ are readonly attributes of a function, we cannot rely on the native setstate routine of pickle.load_build, that calls @@ -495,6 +531,9 @@ class CloudPickler(Pickler): _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce _dispatch_table[_collections_abc.dict_items] = _dict_items_reduce + _dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce + _dispatch_table[type(OrderedDict().values())] = _odict_values_reduce + _dispatch_table[type(OrderedDict().items())] = _odict_items_reduce dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table) @@ -520,7 +559,7 @@ def _function_reduce(self, obj): As opposed to cloudpickle.py, There no special handling for builtin pypy functions because cloudpickle_fast is CPython-specific. """ - if _is_importable(obj): + if _should_pickle_by_reference(obj): return NotImplemented else: return self._dynamic_function_reduce(obj) @@ -579,7 +618,7 @@ def dump(self, obj): # `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler # used `CloudPickler.dispatch` as a class-level attribute storing all # reducers implemented by cloudpickle, but the attribute name was not a - # great choice given the meaning of `Cloudpickler.dispatch` when + # great choice given the meaning of `CloudPickler.dispatch` when # `CloudPickler` extends the pure-python pickler. dispatch = dispatch_table @@ -653,7 +692,7 @@ def reducer_override(self, obj): return self._function_reduce(obj) else: # fallback to save_global, including the Pickler's - # distpatch_table + # dispatch_table return NotImplemented else: @@ -724,7 +763,7 @@ def save_global(self, obj, name=None, pack=struct.pack): ) elif name is not None: Pickler.save_global(self, obj, name=name) - elif not _is_importable(obj, name=name): + elif not _should_pickle_by_reference(obj, name=name): self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj) else: Pickler.save_global(self, obj, name=name) @@ -736,7 +775,7 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ - if _is_importable(obj, name=name): + if _should_pickle_by_reference(obj, name=name): return Pickler.save_global(self, obj, name=name) elif PYPY and isinstance(obj.__code__, builtin_code_type): return self.save_pypy_builtin_func(obj) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2c789947af91..336024fff808 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -35,7 +35,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway, local_connect_and_auth from pyspark.serializers import ( - PickleSerializer, + CPickleSerializer, BatchedSerializer, UTF8Deserializer, PairDeserializer, @@ -49,7 +49,7 @@ from pyspark.taskcontext import TaskContext from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker -from pyspark.profiler import ProfilerCollector, BasicProfiler +from pyspark.profiler import ProfilerCollector, BasicProfiler, UDFBasicProfiler __all__ = ["SparkContext"] @@ -105,6 +105,9 @@ class SparkContext(object): profiler_cls : type, optional A class of custom Profiler used to do profiling (default is :class:`pyspark.profiler.BasicProfiler`). + udf_profiler_cls : type, optional + A class of custom Profiler used to do udf profiling + (default is :class:`pyspark.profiler.UDFBasicProfiler`). 
Notes ----- @@ -142,11 +145,12 @@ def __init__( pyFiles=None, environment=None, batchSize=0, - serializer=PickleSerializer(), + serializer=CPickleSerializer(), conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler, + udf_profiler_cls=UDFBasicProfiler, ): if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true": # In order to prevent SparkContext from being created in executors. @@ -172,6 +176,7 @@ def __init__( conf, jsc, profiler_cls, + udf_profiler_cls, ) except: # If an error occurs, clean up in order to allow future SparkContext creation: @@ -190,6 +195,7 @@ def _do_init( conf, jsc, profiler_cls, + udf_profiler_cls, ): self.environment = environment or {} # java gateway must have been launched at this point. @@ -319,7 +325,7 @@ def _do_init( # profiling stats collected for each PythonRDD if self._conf.get("spark.python.profile", "false") == "true": dump_path = self._conf.get("spark.python.profile.dump", None) - self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + self.profiler_collector = ProfilerCollector(profiler_cls, udf_profiler_cls, dump_path) else: self.profiler_collector = None @@ -814,7 +820,7 @@ def sequenceFile( and value Writable classes 2. Serialization is attempted via Pickle pickling 3. If this fails, the fallback is to call 'toString' on each key and value - 4. :class:`PickleSerializer` is used to deserialize pickled objects on the Python side + 4. :class:`CPickleSerializer` is used to deserialize pickled objects on the Python side Parameters ---------- diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi index 640a69cad08a..f1350aaec94c 100644 --- a/python/pyspark/context.pyi +++ b/python/pyspark/context.pyi @@ -62,6 +62,7 @@ class SparkContext: gateway: Optional[JavaGateway] = ..., jsc: Optional[JavaObject] = ..., profiler_cls: type = ..., + udf_profiler_cls: type = ..., ) -> None: ... def __getnewargs__(self) -> NoReturn: ... def __enter__(self) -> SparkContext: ... diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index 194a492aff7c..61b20f131d23 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -27,7 +27,7 @@ import pyspark.context from pyspark import RDD, SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.sql import DataFrame, SparkSession # Hack for support float('inf') in Py4j @@ -65,7 +65,7 @@ def _to_java_object_rdd(rdd: RDD) -> JavaObject: It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. 
""" - rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) # type: ignore[attr-defined] + rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) # type: ignore[attr-defined] return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True) # type: ignore[attr-defined] @@ -84,7 +84,7 @@ def _py2java(sc: SparkContext, obj: Any) -> JavaObject: elif isinstance(obj, (int, float, bool, bytes, str)): pass else: - data = bytearray(PickleSerializer().dumps(obj)) + data = bytearray(CPickleSerializer().dumps(obj)) obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data) # type: ignore[attr-defined] return obj @@ -113,7 +113,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt pass # not picklable if isinstance(r, (bytearray, bytes)): - r = PickleSerializer().loads(bytes(r), encoding=encoding) + r = CPickleSerializer().loads(bytes(r), encoding=encoding) return r diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 5db6c048bfcc..dfdd32e98ebf 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -20,7 +20,7 @@ from numpy import arange, array, array_equal, inf, ones, tile, zeros -from pyspark.serializers import PickleSerializer +from pyspark.serializers import CPickleSerializer from pyspark.ml.linalg import ( DenseMatrix, DenseVector, @@ -37,7 +37,7 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): - ser = PickleSerializer() + ser = CPickleSerializer() self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 540545dc1114..5f109be2a176 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -27,7 +27,7 @@ import pyspark.context from pyspark import RDD, SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.sql import DataFrame, SparkSession # Hack for support float('inf') in Py4j @@ -67,7 +67,7 @@ def _to_java_object_rdd(rdd: RDD) -> JavaObject: It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. 
""" - rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) # type: ignore[attr-defined] + rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) # type: ignore[attr-defined] return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True) # type: ignore[attr-defined] @@ -86,7 +86,7 @@ def _py2java(sc: SparkContext, obj: Any) -> JavaObject: elif isinstance(obj, (int, float, bool, bytes, str)): pass else: - data = bytearray(PickleSerializer().dumps(obj)) + data = bytearray(CPickleSerializer().dumps(obj)) obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data) # type: ignore[attr-defined] return obj @@ -115,7 +115,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt pass # not pickable if isinstance(r, (bytearray, bytes)): - r = PickleSerializer().loads(bytes(r), encoding=encoding) + r = CPickleSerializer().loads(bytes(r), encoding=encoding) return r diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index 6927b75e3d44..fd9f348f31bf 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -26,7 +26,7 @@ from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint -from pyspark.serializers import PickleSerializer +from pyspark.serializers import CPickleSerializer from pyspark.testing.mllibutils import MLlibTestCase @@ -303,7 +303,7 @@ def test_regression(self): class ALSTests(MLlibTestCase): def test_als_ratings_serialize(self): - ser = PickleSerializer() + ser = CPickleSerializer() r = Rating(7, 1123, 3.14) jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) @@ -312,7 +312,7 @@ def test_als_ratings_serialize(self): self.assertAlmostEqual(r.rating, nr.rating, 2) def test_als_ratings_id_long_error(self): - ser = PickleSerializer() + ser = CPickleSerializer() r = Rating(1205640308657491975, 50233468418, 1.0) # rating user id exceeds max int value, should fail when pickled self.assertRaises( diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index e43482dc415e..d60396b633c2 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -21,7 +21,7 @@ from numpy import array, array_equal, zeros, arange, tile, ones, inf import pyspark.ml.linalg as newlinalg -from pyspark.serializers import PickleSerializer +from pyspark.serializers import CPickleSerializer from pyspark.mllib.linalg import ( # type: ignore[attr-defined] Vector, SparseVector, @@ -43,7 +43,7 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): - ser = PickleSerializer() + ser = CPickleSerializer() self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) @@ -512,7 +512,7 @@ class SciPyTests(MLlibTestCase): def test_serialize(self): from scipy.sparse import lil_matrix - ser = PickleSerializer() + ser = CPickleSerializer() lil = lil_matrix((4, 1)) lil[1, 0] = 1 lil[3, 0] = 2 diff --git a/python/pyspark/pandas/__init__.py b/python/pyspark/pandas/__init__.py index ea8a9ea63935..04128ed84c73 100644 --- a/python/pyspark/pandas/__init__.py +++ b/python/pyspark/pandas/__init__.py 
@@ -144,4 +144,4 @@ def _auto_patch_pandas() -> None: # Import after the usage logger is attached. from pyspark.pandas.config import get_option, options, option_context, reset_option, set_option from pyspark.pandas.namespace import * # F405 -from pyspark.pandas.sql_processor import sql +from pyspark.pandas.sql_formatter import sql diff --git a/python/pyspark/pandas/config.py b/python/pyspark/pandas/config.py index a6689c8fdeee..8e5c8081095e 100644 --- a/python/pyspark/pandas/config.py +++ b/python/pyspark/pandas/config.py @@ -201,7 +201,8 @@ def validate(self, v: Any) -> None: "of validation. If 'compute.eager_check' is set to True, pandas-on-Spark performs the " "validation beforehand, but it will cause a performance overhead. Otherwise, " "pandas-on-Spark skip the validation and will be slightly different from pandas. " - "Affected APIs: `Series.dot`." + "Affected APIs: `Series.dot`, `Series.asof`, `FractionalExtensionOps.astype`, " + "`IntegralExtensionOps.astype`, `FractionalOps.astype`, `DecimalOps.astype`." ), default=True, types=bool, diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index e08d6e9abbe8..f9e068fd2c3f 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -24,6 +24,7 @@ from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex from pyspark.pandas.base import column_op, IndexOpsMixin, numpy_column_op +from pyspark.pandas.config import get_option from pyspark.pandas.data_type_ops.base import ( DataTypeOps, is_valid_operand_for_numeric_arithmetic, @@ -388,7 +389,7 @@ def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> Ind dtype, spark_type = pandas_on_spark_type(dtype) if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: + if get_option("compute.eager_check") and index_ops.hasnans: raise ValueError( "Cannot convert %s with missing values to integer" % self.pretty_name ) @@ -449,7 +450,7 @@ def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: + if get_option("compute.eager_check") and index_ops.hasnans: raise ValueError( "Cannot convert %s with missing values to integer" % self.pretty_name ) @@ -490,15 +491,17 @@ def restore(self, col: pd.Series) -> pd.Series: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) - - if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError( - "Cannot convert %s with missing values to integer" % self.pretty_name - ) - elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError("Cannot convert %s with missing values to bool" % self.pretty_name) + if get_option("compute.eager_check"): + if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to integer" % self.pretty_name + ) + elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to bool" % self.pretty_name + ) return _non_fractional_astype(index_ops, dtype, 
spark_type) @@ -517,15 +520,17 @@ def restore(self, col: pd.Series) -> pd.Series: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) - - if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError( - "Cannot convert %s with missing values to integer" % self.pretty_name - ) - elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError("Cannot convert %s with missing values to bool" % self.pretty_name) + if get_option("compute.eager_check"): + if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to integer" % self.pretty_name + ) + elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to bool" % self.pretty_name + ) if isinstance(dtype, CategoricalDtype): return _as_categorical_type(index_ops, dtype, spark_type) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 38ac9af9c163..edfb62ef2876 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -77,6 +77,7 @@ StringType, StructField, StructType, + DecimalType, ) from pyspark.sql.window import Window @@ -8258,6 +8259,194 @@ def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True) internal = self._internal.with_new_sdf(sdf, data_fields=data_fields) self._update_internal_frame(internal, requires_same_anchor=False) + # TODO: ddof should be implemented. + def cov(self, min_periods: Optional[int] = None) -> "DataFrame": + """ + Compute pairwise covariance of columns, excluding NA/null values. + + Compute the pairwise covariance among the series of a DataFrame. + The returned data frame is the `covariance matrix + `__ of the columns + of the DataFrame. + + Both NA and null values are automatically excluded from the + calculation. (See the note below about bias from missing values.) + A threshold can be set for the minimum number of + observations for each value created. Comparisons with observations + below this threshold will be returned as ``NaN``. + + This method is generally used for the analysis of time series data to + understand the relationship between different measures + across time. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + min_periods : int, optional + Minimum number of observations required per pair of columns + to have a valid result. + + Returns + ------- + DataFrame + The covariance matrix of the series of the DataFrame. + + See Also + -------- + Series.cov : Compute covariance with another Series. + + Examples + -------- + >>> df = ps.DataFrame([(1, 2), (0, 3), (2, 0), (1, 1)], + ... columns=['dogs', 'cats']) + >>> df.cov() + dogs cats + dogs 0.666667 -1.000000 + cats -1.000000 1.666667 + + >>> np.random.seed(42) + >>> df = ps.DataFrame(np.random.randn(1000, 5), + ... 
columns=['a', 'b', 'c', 'd', 'e']) + >>> df.cov() + a b c d e + a 0.998438 -0.020161 0.059277 -0.008943 0.014144 + b -0.020161 1.059352 -0.008543 -0.024738 0.009826 + c 0.059277 -0.008543 1.010670 -0.001486 -0.000271 + d -0.008943 -0.024738 -0.001486 0.921297 -0.013692 + e 0.014144 0.009826 -0.000271 -0.013692 0.977795 + + **Minimum number of periods** + + This method also supports an optional ``min_periods`` keyword + that specifies the required minimum number of non-NA observations for + each column pair in order to have a valid result: + + >>> np.random.seed(42) + >>> df = pd.DataFrame(np.random.randn(20, 3), + ... columns=['a', 'b', 'c']) + >>> df.loc[df.index[:5], 'a'] = np.nan + >>> df.loc[df.index[5:10], 'b'] = np.nan + >>> sdf = ps.from_pandas(df) + >>> sdf.cov(min_periods=12) + a b c + a 0.316741 NaN -0.150812 + b NaN 1.248003 0.191417 + c -0.150812 0.191417 0.895202 + """ + min_periods = 1 if min_periods is None else min_periods + + # Only compute covariance for Boolean and Numeric except Decimal + psdf = self[ + [ + col + for col in self.columns + if isinstance(self[col].spark.data_type, BooleanType) + or ( + isinstance(self[col].spark.data_type, NumericType) + and not isinstance(self[col].spark.data_type, DecimalType) + ) + ] + ] + + num_cols = len(psdf.columns) + cov = np.zeros([num_cols, num_cols]) + + if num_cols == 0: + return DataFrame() + + if len(psdf) < min_periods: + cov.fill(np.nan) + return DataFrame(cov, columns=psdf.columns, index=psdf.columns) + + data_cols = psdf._internal.data_spark_column_names + cov_scols = [] + count_not_null_scols = [] + + # Count number of null row between two columns + # Example: + # a b c + # 0 1 1 1 + # 1 NaN 2 2 + # 2 3 NaN 3 + # 3 4 4 4 + # + # a b c + # a count(a, a) count(a, b) count(a, c) + # b count(b, b) count(b, c) + # c count(c, c) + # + # count_not_null_scols = + # [F.count(a, a), F.count(a, b), F.count(a, c), F.count(b, b), F.count(b, c), F.count(c, c)] + for r in range(0, num_cols): + for c in range(r, num_cols): + count_not_null_scols.append( + F.count( + F.when(F.col(data_cols[r]).isNotNull() & F.col(data_cols[c]).isNotNull(), 1) + ) + ) + + count_not_null = ( + psdf._internal.spark_frame.replace(float("nan"), None) + .select(*count_not_null_scols) + .head(1)[0] + ) + + # Calculate covariance between two columns + # Example: + # with min_periods = 3 + # a b c + # 0 1 1 1 + # 1 NaN 2 2 + # 2 3 NaN 3 + # 3 4 4 4 + # + # a b c + # a cov(a, a) None cov(a, c) + # b cov(b, b) cov(b, c) + # c cov(c, c) + # + # cov_scols = [F.cov(a, a), None, F.cov(a, c), F.cov(b, b), F.cov(b, c), F.cov(c, c)] + step = 0 + for r in range(0, num_cols): + step += r + for c in range(r, num_cols): + cov_scols.append( + F.covar_samp( + F.col(data_cols[r]).cast("double"), F.col(data_cols[c]).cast("double") + ) + if count_not_null[r * num_cols + c - step] >= min_periods + else F.lit(None) + ) + + pair_cov = psdf._internal.spark_frame.select(*cov_scols).head(1)[0] + + # Convert from row to 2D array + # Example: + # pair_cov = [cov(a, a), None, cov(a, c), cov(b, b), cov(b, c), cov(c, c)] + # + # cov = + # + # a b c + # a cov(a, a) None cov(a, c) + # b cov(b, b) cov(b, c) + # c cov(c, c) + step = 0 + for r in range(0, num_cols): + step += r + for c in range(r, num_cols): + cov[r][c] = pair_cov[r * num_cols + c - step] + + # Copy values + # Example: + # cov = + # a b c + # a cov(a, a) None cov(a, c) + # b None cov(b, b) cov(b, c) + # c cov(a, c) cov(b, c) cov(c, c) + cov = cov + cov.T - np.diag(np.diag(cov)) + return DataFrame(cov, columns=psdf.columns, 
index=psdf.columns) + def sample( self, n: Optional[int] = None, diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py index aabc0e042e73..d822c1419247 100644 --- a/python/pyspark/pandas/missing/frame.py +++ b/python/pyspark/pandas/missing/frame.py @@ -39,7 +39,6 @@ class _MissingPandasLikeDataFrame(object): compare = _unsupported_function("compare") convert_dtypes = _unsupported_function("convert_dtypes") corrwith = _unsupported_function("corrwith") - cov = _unsupported_function("cov") ewm = _unsupported_function("ewm") infer_objects = _unsupported_function("infer_objects") interpolate = _unsupported_function("interpolate") diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index f6ec5e943a4c..4a459b6c33a4 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -5160,7 +5160,7 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: If there is no good value, NaN is returned. .. note:: This API is dependent on :meth:`Index.is_monotonic_increasing` - which can be expensive. + which is expensive. Parameters ---------- @@ -5179,7 +5179,9 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: Notes ----- - Indices are assumed to be sorted. Raises if this is not the case. + Indices are assumed to be sorted. Raises if this is not the case and config + 'compute.eager_check' is True. If 'compute.eager_check' is False pandas-on-Spark just + proceeds and performs by ignoring the indeces's order Examples -------- @@ -5210,13 +5212,19 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: >>> s.asof(30) 2.0 + + >>> s = ps.Series([1, 2, np.nan, 4], index=[10, 30, 20, 40]) + >>> with ps.option_context("compute.eager_check", False): + ... s.asof(20) + ... + 1.0 """ should_return_series = True if isinstance(self.index, ps.MultiIndex): raise ValueError("asof is not supported for a MultiIndex") if isinstance(where, (ps.Index, ps.Series, DataFrame)): raise ValueError("where cannot be an Index, Series or a DataFrame") - if not self.index.is_monotonic_increasing: + if get_option("compute.eager_check") and not self.index.is_monotonic_increasing: raise ValueError("asof requires a sorted index") if not is_list_like(where): should_return_series = False diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py new file mode 100644 index 000000000000..685ee25cc669 --- /dev/null +++ b/python/pyspark/pandas/sql_formatter.py @@ -0,0 +1,273 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
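The astype and asof changes above all hang off the existing compute.eager_check option: with the default True, pandas-on-Spark still eagerly scans for missing values or an unsorted index and raises; with False, the check is skipped and the result may deviate from pandas. A minimal sketch of the intended behaviour once this patch is applied:

    import numpy as np
    import pyspark.pandas as ps

    psser = ps.Series([1.0, 2.0, np.nan])

    with ps.option_context("compute.eager_check", True):
        try:
            psser.astype(int)        # eager NaN check still raises ValueError
        except ValueError as e:
            print(e)

    with ps.option_context("compute.eager_check", False):
        psser.astype(int)            # check skipped; output may differ from pandas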
+# + +import os +import string +from typing import Any, Optional, Union, List, Sequence, Mapping, Tuple +import uuid +import warnings + +import pandas as pd + +from pyspark.pandas.internal import InternalFrame +from pyspark.pandas.namespace import _get_index_map +from pyspark.sql.functions import lit +from pyspark import pandas as ps +from pyspark.sql import SparkSession +from pyspark.pandas.utils import default_session +from pyspark.pandas.frame import DataFrame +from pyspark.pandas.series import Series + + +__all__ = ["sql"] + + +# This is not used in this file. It's for legacy sql_processor. +_CAPTURE_SCOPES = 3 + + +def sql( + query: str, + index_col: Optional[Union[str, List[str]]] = None, + **kwargs: Any, +) -> DataFrame: + """ + Execute a SQL query and return the result as a pandas-on-Spark DataFrame. + + This function acts as a standard Python string formatter with understanding + the following variable types: + + * pandas-on-Spark DataFrame + * pandas-on-Spark Series + * pandas DataFrame + * pandas Series + * string + + Parameters + ---------- + query : str + the SQL query + index_col : str or list of str, optional + Column names to be used in Spark to represent pandas-on-Spark's index. The index name + in pandas-on-Spark is ignored. By default, the index is always lost. + + .. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`, + and pass it to the sql statement with `index_col` parameter. + + For example, + + >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c']) + >>> new_psdf = psdf.reset_index() + >>> ps.sql("SELECT * FROM {new_psdf}", index_col="index", new_psdf=new_psdf) + ... # doctest: +NORMALIZE_WHITESPACE + A B + index + a 1 4 + b 2 5 + c 3 6 + + For MultiIndex, + + >>> psdf = ps.DataFrame( + ... {"A": [1, 2, 3], "B": [4, 5, 6]}, + ... index=pd.MultiIndex.from_tuples( + ... [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"] + ... ), + ... ) + >>> new_psdf = psdf.reset_index() + >>> ps.sql("SELECT * FROM {new_psdf}", index_col=["index1", "index2"], new_psdf=new_psdf) + ... # doctest: +NORMALIZE_WHITESPACE + A B + index1 index2 + a b 1 4 + c d 2 5 + e f 3 6 + + Also note that the index name(s) should be matched to the existing name. + kwargs + other variables that the user want to set that can be referenced in the query + + Returns + ------- + pandas-on-Spark DataFrame + + Examples + -------- + + Calling a built-in SQL function. + + >>> ps.sql("SELECT * FROM range(10) where id > 7") + id + 0 8 + 1 9 + + >>> ps.sql("SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9) + id + 0 8 + + >>> mydf = ps.range(10) + >>> x = tuple(range(4)) + >>> ps.sql("SELECT {ser} FROM {mydf} WHERE id IN {x}", ser=mydf.id, mydf=mydf, x=x) + id + 0 0 + 1 1 + 2 2 + 3 3 + + Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is + dropped. + + >>> ps.sql(''' + ... SELECT m1.a, m2.b + ... FROM {table1} m1 INNER JOIN {table2} m2 + ... ON m1.key = m2.key + ... ORDER BY m1.a, m2.b''', + ... table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}), + ... table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]})) + a b + 0 1 3 + 1 2 4 + 2 2 5 + + Also, it is possible to query using Series. 
+ + >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c']) + >>> ps.sql("SELECT {mydf.A} FROM {mydf}", mydf=psdf) + A + 0 1 + 1 2 + 2 3 + """ + if os.environ.get("PYSPARK_PANDAS_SQL_LEGACY") == "1": + from pyspark.pandas import sql_processor + + warnings.warn( + "Deprecated in 3.3.0, and the legacy behavior " + "will be removed in the future releases.", + FutureWarning, + ) + return sql_processor.sql(query, index_col=index_col, **kwargs) + + session = default_session() + formatter = SQLStringFormatter(session) + try: + sdf = session.sql(formatter.format(query, **kwargs)) + finally: + formatter.clear() + + index_spark_columns, index_names = _get_index_map(sdf, index_col) + + return DataFrame( + InternalFrame( + spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names + ) + ) + + +class SQLStringFormatter(string.Formatter): + """ + A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances + with basic Python objects. This object has to be clear after the use for single SQL + query; cannot be reused across multiple SQL queries without cleaning. + """ + + def __init__(self, session: SparkSession) -> None: + self._session: SparkSession = session + self._temp_views: List[Tuple[DataFrame, str]] = [] + self._ref_sers: List[Tuple[Series, str]] = [] + + def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str: + ret = super(SQLStringFormatter, self).vformat(format_string, args, kwargs) + + for ref, n in self._ref_sers: + if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views): + # If referred DataFrame does not hold the given Series, raise an error. + raise ValueError("The series in {%s} does not refer any dataframe specified." % n) + return ret + + def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: + obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs) + return self._convert_value(obj, field_name), first + + def _convert_value(self, val: Any, name: str) -> Optional[str]: + """ + Converts the given value into a SQL string. + """ + if isinstance(val, pd.Series): + # Return the column name from pandas Series directly. + return ps.from_pandas(val).to_frame()._to_spark().columns[0] + elif isinstance(val, Series): + # Return the column name of pandas-on-Spark Series iff its DataFrame was + # referred. The check will be done in `vformat` after we parse all. + self._ref_sers.append((val, name)) + return val.to_frame()._to_spark().columns[0] + elif isinstance(val, (DataFrame, pd.DataFrame)): + df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "") + + if isinstance(val, pd.DataFrame): + # Don't store temp view for plain pandas instances + # because it is unable to know which pandas DataFrame + # holds which Series. + val = ps.from_pandas(val) + else: + for df, n in self._temp_views: + if df is val: + return n + self._temp_views.append((val, df_name)) + + val._to_spark().createOrReplaceTempView(df_name) + return df_name + elif isinstance(val, str): + return lit(val)._jc.expr().sql() # for escaped characters. 
+ else: + return val + + def clear(self) -> None: + for _, n in self._temp_views: + self._session.catalog.dropTempView(n) + self._temp_views = [] + self._ref_sers = [] + + +def _test() -> None: + import os + import doctest + import sys + from pyspark.sql import SparkSession + import pyspark.pandas.sql_formatter + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.pandas.sql_formatter.__dict__.copy() + globs["ps"] = pyspark.pandas + spark = ( + SparkSession.builder.master("local[4]") + .appName("pyspark.pandas.sql_processor tests") + .getOrCreate() + ) + (failure_count, test_count) = doctest.testmod( + pyspark.pandas.sql_formatter, + globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE, + ) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/pandas/sql_processor.py b/python/pyspark/pandas/sql_processor.py index afdaa101d679..8126d1e10a44 100644 --- a/python/pyspark/pandas/sql_processor.py +++ b/python/pyspark/pandas/sql_processor.py @@ -77,9 +77,13 @@ def sql( For example, + >>> from pyspark.pandas import sql_processor + >>> # we will call 'sql_processor' directly in doctests so decrease one level. + >>> sql_processor._CAPTURE_SCOPES = 2 + >>> sql = sql_processor.sql >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c']) >>> psdf_reset_index = psdf.reset_index() - >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col="index") + >>> sql("SELECT * FROM {psdf_reset_index}", index_col="index") ... # doctest: +NORMALIZE_WHITESPACE A B index @@ -96,7 +100,7 @@ def sql( ... ), ... ) >>> psdf_reset_index = psdf.reset_index() - >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"]) + >>> sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"]) ... # doctest: +NORMALIZE_WHITESPACE A B index1 index2 @@ -122,7 +126,7 @@ def sql( Calling a built-in SQL function. - >>> ps.sql("select * from range(10) where id > 7") + >>> sql("select * from range(10) where id > 7") id 0 8 1 9 @@ -130,7 +134,7 @@ def sql( A query can also reference a local variable or parameter by wrapping them in curly braces: >>> bound1 = 7 - >>> ps.sql("select * from range(10) where id > {bound1} and id < {bound2}", bound2=9) + >>> sql("select * from range(10) where id > {bound1} and id < {bound2}", bound2=9) id 0 8 @@ -139,7 +143,7 @@ def sql( >>> mydf = ps.range(10) >>> x = range(4) - >>> ps.sql("SELECT * from {mydf} WHERE id IN {x}") + >>> sql("SELECT * from {mydf} WHERE id IN {x}") id 0 0 1 1 @@ -150,7 +154,7 @@ def sql( >>> def statement(): ... mydf2 = ps.DataFrame({"x": range(2)}) - ... return ps.sql("SELECT * from {mydf2}") + ... return sql("SELECT * from {mydf2}") >>> statement() x 0 0 @@ -159,7 +163,7 @@ def sql( Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is dropped. - >>> ps.sql(''' + >>> sql(''' ... SELECT m1.a, m2.b ... FROM {table1} m1 INNER JOIN {table2} m2 ... ON m1.key = m2.key @@ -174,7 +178,7 @@ def sql( Also, it is possible to query using Series. 
>>> myser = ps.Series({'a': [1.0, 2.0, 3.0], 'b': [15.0, 30.0, 45.0]}) - >>> ps.sql("SELECT * from {myser}") + >>> sql("SELECT * from {myser}") 0 0 [1.0, 2.0, 3.0] 1 [15.0, 30.0, 45.0] @@ -195,7 +199,7 @@ def sql( return SQLProcessor(_dict, query, default_session()).execute(index_col) -_CAPTURE_SCOPES = 2 +_CAPTURE_SCOPES = 3 def _get_local_scope() -> Dict[str, Any]: @@ -272,19 +276,23 @@ def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame: Returns a DataFrame for which the SQL statement has been executed by the underlying SQL engine. + >>> from pyspark.pandas import sql_processor + >>> # we will call 'sql_processor' directly in doctests so decrease one level. + >>> sql_processor._CAPTURE_SCOPES = 2 + >>> sql = sql_processor.sql >>> str0 = 'abc' - >>> ps.sql("select {str0}") + >>> sql("select {str0}") abc 0 abc >>> str1 = 'abc"abc' >>> str2 = "abc'abc" - >>> ps.sql("select {str0}, {str1}, {str2}") + >>> sql("select {str0}, {str1}, {str2}") abc abc"abc abc'abc 0 abc abc"abc abc'abc >>> strs = ['a', 'b'] - >>> ps.sql("select 'a' in {strs} as cond1, 'c' in {strs} as cond2") + >>> sql("select 'a' in {strs} as cond1, 'c' in {strs} as cond2") cond1 cond2 0 True False """ diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py index f4b36f969a05..77fc93c0eba4 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py @@ -369,6 +369,23 @@ def test_astype(self): psser = ps.from_pandas(pser) self.assert_eq(pser.astype(pd.BooleanDtype()), psser.astype(pd.BooleanDtype())) + def test_astype_eager_check(self): + psser = self.psdf["float_nan"] + with ps.option_context("compute.eager_check", True), self.assertRaisesRegex( + ValueError, "Cannot convert" + ): + psser.astype(int) + with ps.option_context("compute.eager_check", False): + psser.astype(int) + + psser = self.psdf["decimal_nan"] + with ps.option_context("compute.eager_check", True), self.assertRaisesRegex( + ValueError, "Cannot convert" + ): + psser.astype(int) + with ps.option_context("compute.eager_check", False): + psser.astype(int) + def test_neg(self): pdf, psdf = self.pdf, self.psdf for col in self.numeric_df_cols: @@ -475,21 +492,26 @@ def test_astype(self): for pser, psser in self.intergral_extension_pser_psser_pairs: self.assert_eq(pser.astype(float), psser.astype(float)) self.assert_eq(pser.astype(np.float32), psser.astype(np.float32)) - self.assertRaisesRegex( - ValueError, - "Cannot convert integrals with missing values to bool", - lambda: psser.astype(bool), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert integrals with missing values to integer", - lambda: psser.astype(int), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert integrals with missing values to integer", - lambda: psser.astype(np.int32), - ) + with ps.option_context("compute.eager_check", True): + self.assertRaisesRegex( + ValueError, + "Cannot convert integrals with missing values to bool", + lambda: psser.astype(bool), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert integrals with missing values to integer", + lambda: psser.astype(int), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert integrals with missing values to integer", + lambda: psser.astype(np.int32), + ) + with ps.option_context("compute.eager_check", False): + psser.astype(bool) + psser.astype(int) + psser.astype(np.int32) def test_neg(self): for pser, psser in 
self.intergral_extension_pser_psser_pairs: @@ -607,21 +629,26 @@ def test_astype(self): for pser, psser in self.fractional_extension_pser_psser_pairs: self.assert_eq(pser.astype(float), psser.astype(float)) self.assert_eq(pser.astype(np.float32), psser.astype(np.float32)) - self.assertRaisesRegex( - ValueError, - "Cannot convert fractions with missing values to bool", - lambda: psser.astype(bool), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert fractions with missing values to integer", - lambda: psser.astype(int), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert fractions with missing values to integer", - lambda: psser.astype(np.int32), - ) + with ps.option_context("compute.eager_check", True): + self.assertRaisesRegex( + ValueError, + "Cannot convert fractions with missing values to bool", + lambda: psser.astype(bool), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert fractions with missing values to integer", + lambda: psser.astype(int), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert fractions with missing values to integer", + lambda: psser.astype(np.int32), + ) + with ps.option_context("compute.eager_check", False): + psser.astype(bool) + psser.astype(int) + psser.astype(np.int32) def test_neg(self): # pandas raises "TypeError: bad operand type for unary -: 'FloatingArray'" diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 701052ed2ce3..ae8fcaef89b7 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import decimal from datetime import datetime from distutils.version import LooseVersion import inspect @@ -6025,6 +6025,69 @@ def test_multi_index_dtypes(self): ) self.assert_eq(psmidx.dtypes, expected) + def test_cov(self): + # SPARK-36396: Implement DataFrame.cov + + # int + pdf = pd.DataFrame([(1, 2), (0, 3), (2, 0), (1, 1)], columns=["a", "b"]) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True) + self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5)) + + # bool + pdf = pd.DataFrame( + { + "a": [1, np.nan, 3, 4], + "b": [True, False, False, True], + "c": [True, True, False, True], + } + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True) + self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5)) + + # extension dtype + numeric_dtypes = ["Int8", "Int16", "Int32", "Int64", "Float32", "Float64", "float"] + boolean_dtypes = ["boolean", "bool"] + + sers = [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in numeric_dtypes] + sers += [pd.Series([True, False, True, None], dtype=dtype) for dtype in boolean_dtypes] + sers.append(pd.Series([decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3), None])) + + pdf = pd.concat(sers, axis=1) + pdf.columns = [dtype for dtype in numeric_dtypes + boolean_dtypes] + ["decimal"] + psdf = ps.from_pandas(pdf) + + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=3), psdf.cov(min_periods=3), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4)) + + # string column + pdf = pd.DataFrame( + [(1, 2, "a", 1), (0, 3, "b", 1), (2, 0, "c", 9), (1, 1, "d", 1)], + 
columns=["a", "b", "c", "d"], + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True) + self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5)) + + # nan + np.random.seed(42) + pdf = pd.DataFrame(np.random.randn(20, 3), columns=["a", "b", "c"]) + pdf.loc[pdf.index[:5], "a"] = np.nan + pdf.loc[pdf.index[5:10], "b"] = np.nan + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(min_periods=11), psdf.cov(min_periods=11), almost=True) + self.assert_eq(pdf.cov(min_periods=10), psdf.cov(min_periods=10), almost=True) + + # return empty DataFrame + pdf = pd.DataFrame([("1", "2"), ("0", "3"), ("2", "0"), ("1", "1")], columns=["a", "b"]) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov()) + if __name__ == "__main__": from pyspark.pandas.tests.test_dataframe import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index 72677d18e4b8..51c26ad8301b 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -2115,6 +2115,19 @@ def test_asof(self): self.assert_eq(psser.asof("2014-01-02"), pser.asof("2014-01-02")) self.assert_eq(repr(psser.asof("1999-01-02")), repr(pser.asof("1999-01-02"))) + # SPARK-37482: Skip check monotonic increasing for Series.asof with 'compute.eager_check' + pser = pd.Series([1, 2, np.nan, 4], index=[10, 30, 20, 40]) + psser = ps.from_pandas(pser) + + with ps.option_context("compute.eager_check", False): + self.assert_eq(psser.asof(20), 1.0) + + pser = pd.Series([1, 2, np.nan, 4], index=[40, 30, 20, 10]) + psser = ps.from_pandas(pser) + + with ps.option_context("compute.eager_check", False): + self.assert_eq(psser.asof(20), 4.0) + def test_squeeze(self): # Single value pser = pd.Series([90]) @@ -2232,7 +2245,9 @@ def test_mad(self): pser.index = pmidx psser = ps.from_pandas(pser) - self.assert_eq(pser.mad(), psser.mad()) + # Mark almost as True to avoid precision issue like: + # "21.555555555555554 != 21.555555555555557" + self.assert_eq(pser.mad(), psser.mad(), almost=True) def test_to_frame(self): pser = pd.Series(["a", "b", "c"]) diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py index 306ea166cf93..ca0dd99a3209 100644 --- a/python/pyspark/pandas/tests/test_sql.py +++ b/python/pyspark/pandas/tests/test_sql.py @@ -23,20 +23,22 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils): def test_error_variable_not_exist(self): - msg = "The key variable_foo in the SQL statement was not found.*" - with self.assertRaisesRegex(ValueError, msg): + with self.assertRaisesRegex(KeyError, "variable_foo"): ps.sql("select * from {variable_foo}") def test_error_unsupported_type(self): - msg = "Unsupported variable type dict: {'a': 1}" - with self.assertRaisesRegex(ValueError, msg): - some_dict = {"a": 1} + with self.assertRaisesRegex(KeyError, "some_dict"): ps.sql("select * from {some_dict}") def test_error_bad_sql(self): with self.assertRaises(ParseException): ps.sql("this is not valid sql") + def test_series_not_referred(self): + psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + with self.assertRaisesRegex(ValueError, "The series in {ser}"): + ps.sql("SELECT {ser} FROM range(10)", ser=psdf.A) + def test_sql_with_index_col(self): import pandas as pd @@ -45,7 +47,11 @@ def test_sql_with_index_col(self): {"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"], name="index") ) psdf_reset_index = 
psdf.reset_index() - actual = ps.sql("select * from {psdf_reset_index} where A > 1", index_col="index") + actual = ps.sql( + "select * from {psdf_reset_index} where A > 1", + index_col="index", + psdf_reset_index=psdf_reset_index, + ) expected = psdf.iloc[[1, 2]] self.assert_eq(actual, expected) @@ -58,11 +64,40 @@ def test_sql_with_index_col(self): ) psdf_reset_index = psdf.reset_index() actual = ps.sql( - "select * from {psdf_reset_index} where A > 1", index_col=["index1", "index2"] + "select * from {psdf_reset_index} where A > 1", + index_col=["index1", "index2"], + psdf_reset_index=psdf_reset_index, ) expected = psdf.iloc[[1, 2]] self.assert_eq(actual, expected) + def test_sql_with_pandas_objects(self): + import pandas as pd + + pdf = pd.DataFrame({"a": [1, 2, 3, 4]}) + self.assert_eq(ps.sql("SELECT {col} + 1 as a FROM {tbl}", col=pdf.a, tbl=pdf), pdf + 1) + + def test_sql_with_python_objects(self): + self.assert_eq( + ps.sql("SELECT {col} as a FROM range(1)", col="lit"), ps.DataFrame({"a": ["lit"]}) + ) + self.assert_eq( + ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", col="lit", pred=(1, 2, 3)), + ps.DataFrame({"id": [1, 2, 3]}), + ) + + def test_sql_with_pandas_on_spark_objects(self): + psdf = ps.DataFrame({"a": [1, 2, 3, 4]}) + + self.assert_eq(ps.sql("SELECT {col} FROM {tbl}", col=psdf.a, tbl=psdf), psdf) + self.assert_eq(ps.sql("SELECT {tbl.a} FROM {tbl}", tbl=psdf), psdf) + + psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + self.assert_eq( + ps.sql("SELECT {col}, {col2} FROM {tbl}", col=psdf.A, col2=psdf.B, tbl=psdf), psdf + ) + self.assert_eq(ps.sql("SELECT {tbl.A}, {tbl.B} FROM {tbl}", tbl=psdf), psdf) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/pandas/usage_logging/__init__.py b/python/pyspark/pandas/usage_logging/__init__.py index ebd23ac6376a..b350faf6b9ca 100644 --- a/python/pyspark/pandas/usage_logging/__init__.py +++ b/python/pyspark/pandas/usage_logging/__init__.py @@ -25,7 +25,7 @@ import pandas as pd -from pyspark.pandas import config, namespace, sql_processor +from pyspark.pandas import config, namespace, sql_formatter from pyspark.pandas.accessors import PandasOnSparkFrameMethods from pyspark.pandas.frame import DataFrame from pyspark.pandas.datetimes import DatetimeMethods @@ -113,8 +113,8 @@ def attach(logger_module: Union[str, ModuleType]) -> None: except ImportError: pass - sql_processor._CAPTURE_SCOPES = 3 - modules.append(sql_processor) + sql_formatter._CAPTURE_SCOPES = 4 + modules.append(sql_formatter) # Modules for target_module in modules: diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 778b23fe2b90..3d3a65a46c7b 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -27,12 +27,13 @@ class ProfilerCollector(object): """ This class keeps track of different profilers on a per - stage basis. Also this is used to create new profilers for - the different stages. + stage/UDF basis. Also this is used to create new profilers for + the different stages/UDFs. 
""" - def __init__(self, profiler_cls, dump_path=None): + def __init__(self, profiler_cls, udf_profiler_cls, dump_path=None): self.profiler_cls = profiler_cls + self.udf_profiler_cls = udf_profiler_cls self.profile_dump_path = dump_path self.profilers = [] @@ -40,8 +41,12 @@ def new_profiler(self, ctx): """Create a new profiler using class `profiler_cls`""" return self.profiler_cls(ctx) + def new_udf_profiler(self, ctx): + """Create a new profiler using class `udf_profiler_cls`""" + return self.udf_profiler_cls(ctx) + def add_profiler(self, id, profiler): - """Add a profiler for RDD `id`""" + """Add a profiler for RDD/UDF `id`""" if not self.profilers: if self.profile_dump_path: atexit.register(self.dump_profiles, self.profile_dump_path) @@ -106,7 +111,7 @@ class Profiler(object): def __init__(self, ctx): pass - def profile(self, func): + def profile(self, func, *args, **kwargs): """Do profiling on the function `func`""" raise NotImplementedError @@ -160,10 +165,10 @@ def __init__(self, ctx): # partitions of a stage self._accumulator = ctx.accumulator(None, PStatsParam) - def profile(self, func): + def profile(self, func, *args, **kwargs): """Runs and profiles the method to_profile passed in. A profile object is returned.""" pr = cProfile.Profile() - pr.runcall(func) + ret = pr.runcall(func, *args, **kwargs) st = pstats.Stats(pr) st.stream = None # make it picklable st.strip_dirs() @@ -171,10 +176,36 @@ def profile(self, func): # Adds a new profile to the existing accumulated value self._accumulator.add(st) + return ret + def stats(self): return self._accumulator.value +class UDFBasicProfiler(BasicProfiler): + """ + UDFBasicProfiler is the profiler for Python/Pandas UDFs. + """ + + def show(self, id): + """Print the profile stats to stdout, id is the PythonUDF id""" + stats = self.stats() + if stats: + print("=" * 60) + print("Profile of UDF" % id) + print("=" * 60) + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """Dump the profile into path, id is the PythonUDF id""" + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "udf_%d.pstats" % id) + stats.dump_stats(p) + + if __name__ == "__main__": import doctest diff --git a/python/pyspark/profiler.pyi b/python/pyspark/profiler.pyi index d6a216b7f2fa..85aa6a248036 100644 --- a/python/pyspark/profiler.pyi +++ b/python/pyspark/profiler.pyi @@ -25,17 +25,24 @@ from pyspark.context import SparkContext class ProfilerCollector: profiler_cls: Type[Profiler] + udf_profiler_cls: Type[Profiler] profile_dump_path: Optional[str] profilers: List[Tuple[int, Profiler, bool]] - def __init__(self, profiler_cls: Type[Profiler], dump_path: Optional[str] = ...) -> None: ... + def __init__( + self, + profiler_cls: Type[Profiler], + udf_profiler_cls: Type[Profiler], + dump_path: Optional[str] = ..., + ) -> None: ... def new_profiler(self, ctx: SparkContext) -> Profiler: ... + def new_udf_profiler(self, ctx: SparkContext) -> Profiler: ... def add_profiler(self, id: int, profiler: Profiler) -> None: ... def dump_profiles(self, path: str) -> None: ... def show_profiles(self) -> None: ... class Profiler: def __init__(self, ctx: SparkContext) -> None: ... - def profile(self, func: Callable[[], Any]) -> None: ... + def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... def stats(self) -> pstats.Stats: ... def show(self, id: int) -> None: ... def dump(self, id: int, path: str) -> None: ... 
@@ -50,5 +57,9 @@ class PStatsParam(AccumulatorParam): class BasicProfiler(Profiler): def __init__(self, ctx: SparkContext) -> None: ... - def profile(self, func: Callable[[], Any]) -> None: ... + def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... def stats(self) -> pstats.Stats: ... + +class UDFBasicProfiler(BasicProfiler): + def show(self, id: int) -> None: ... + def dump(self, id: int, path: str) -> None: ... diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b997932c807b..2452d6923704 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,7 +39,7 @@ CartesianDeserializer, CloudPickleSerializer, PairDeserializer, - PickleSerializer, + CPickleSerializer, pack_long, read_int, write_int, @@ -259,7 +259,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(CPickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -270,7 +270,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self.partitioner = None def _pickled(self): - return self._reserialize(AutoBatchedSerializer(PickleSerializer())) + return self._reserialize(AutoBatchedSerializer(CPickleSerializer())) def id(self): """ @@ -1841,7 +1841,7 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): def saveAsPickleFile(self, path, batchSize=10): """ Save this RDD as a SequenceFile of serialized objects. The serializer - used is :class:`pyspark.serializers.PickleSerializer`, default batch size + used is :class:`pyspark.serializers.CPickleSerializer`, default batch size is 10. Examples @@ -1854,9 +1854,9 @@ def saveAsPickleFile(self, path, batchSize=10): ['1', '2', 'rdd', 'spark'] """ if batchSize == 0: - ser = AutoBatchedSerializer(PickleSerializer()) + ser = AutoBatchedSerializer(CPickleSerializer()) else: - ser = BatchedSerializer(PickleSerializer(), batchSize) + ser = BatchedSerializer(CPickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path, compressionCodecClass=None): @@ -2520,7 +2520,7 @@ def coalesce(self, numPartitions, shuffle=False): # Decrease the batch size in order to distribute evenly the elements across output # partitions. Otherwise, repartition will possibly produce highly skewed partitions. batchSize = min(10, self.ctx._batchSize or 1024) - ser = BatchedSerializer(PickleSerializer(), batchSize) + ser = BatchedSerializer(CPickleSerializer(), batchSize) selfCopy = self._reserialize(ser) jrdd_deserializer = selfCopy._jrdd_deserializer jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) @@ -2551,7 +2551,7 @@ def get_batch_size(ser): return 1 # not batched def batch_as(rdd, batchSize): - return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize)) + return rdd._reserialize(BatchedSerializer(CPickleSerializer(), batchSize)) my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 06eed0a3bc41..766ea64d905d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -19,7 +19,7 @@ PySpark supports custom serializers for transferring data; this can improve performance. 
-By default, PySpark uses :class:`PickleSerializer` to serialize objects using Python's +By default, PySpark uses :class:`CloudPickleSerializer` to serialize objects using Python's `cPickle` serializer, which can serialize nearly any Python object. Other serializers, like :class:`MarshalSerializer`, support fewer datatypes but can be faster. @@ -69,7 +69,13 @@ from pyspark.util import print_exec # type: ignore -__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] +__all__ = [ + "PickleSerializer", + "CPickleSerializer", + "CloudPickleSerializer", + "MarshalSerializer", + "UTF8Deserializer", +] class SpecialLengths(object): @@ -344,78 +350,81 @@ def dumps(self, obj): return obj -# Hack namedtuple, make it picklable - -__cls = {} # type: ignore - - -def _restore(name, fields, value): - """Restore an object of namedtuple""" - k = (name, fields) - cls = __cls.get(k) - if cls is None: - cls = collections.namedtuple(name, fields) - __cls[k] = cls - return cls(*value) - - -def _hack_namedtuple(cls): - """Make class generated by namedtuple picklable""" - name = cls.__name__ - fields = cls._fields - - def __reduce__(self): - return (_restore, (name, fields, tuple(self))) - - cls.__reduce__ = __reduce__ - cls._is_namedtuple_ = True - return cls - - -def _hijack_namedtuple(): - """Hack namedtuple() to make it picklable""" - # hijack only one time - if hasattr(collections.namedtuple, "__hijack"): - return - - global _old_namedtuple # or it will put in closure - global _old_namedtuple_kwdefaults # or it will put in closure too - - def _copy_func(f): - return types.FunctionType( - f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__ - ) - - _old_namedtuple = _copy_func(collections.namedtuple) - _old_namedtuple_kwdefaults = collections.namedtuple.__kwdefaults__ - - def namedtuple(*args, **kwargs): - for k, v in _old_namedtuple_kwdefaults.items(): - kwargs[k] = kwargs.get(k, v) - cls = _old_namedtuple(*args, **kwargs) - return _hack_namedtuple(cls) - - # replace namedtuple with the new one - collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults - collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple - collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple - collections.namedtuple.__code__ = namedtuple.__code__ - collections.namedtuple.__hijack = 1 - - # hack the cls already generated by namedtuple. - # Those created in other modules can be pickled as normal, - # so only hack those in __main__ module - for n, o in sys.modules["__main__"].__dict__.items(): - if ( - type(o) is type - and o.__base__ is tuple - and hasattr(o, "_fields") - and "__reduce__" not in o.__dict__ - ): - _hack_namedtuple(o) # hack inplace - +if sys.version_info < (3, 8): + # Hack namedtuple, make it picklable. + # For Python 3.8+, we use CPickle-based cloudpickle. + # For Python 3.7 and below, we use legacy build-in CPickle which + # requires namedtuple hack. + # The whole hack here should be removed once we drop Python 3.7. 
+ + __cls = {} # type: ignore + + def _restore(name, fields, value): + """Restore an object of namedtuple""" + k = (name, fields) + cls = __cls.get(k) + if cls is None: + cls = collections.namedtuple(name, fields) + __cls[k] = cls + return cls(*value) + + def _hack_namedtuple(cls): + """Make class generated by namedtuple picklable""" + name = cls.__name__ + fields = cls._fields + + def __reduce__(self): + return (_restore, (name, fields, tuple(self))) + + cls.__reduce__ = __reduce__ + cls._is_namedtuple_ = True + return cls + + def _hijack_namedtuple(): + """Hack namedtuple() to make it picklable""" + # hijack only one time + if hasattr(collections.namedtuple, "__hijack"): + return -_hijack_namedtuple() + global _old_namedtuple # or it will put in closure + global _old_namedtuple_kwdefaults # or it will put in closure too + + def _copy_func(f): + return types.FunctionType( + f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__ + ) + + _old_namedtuple = _copy_func(collections.namedtuple) + _old_namedtuple_kwdefaults = collections.namedtuple.__kwdefaults__ + + def namedtuple(*args, **kwargs): + for k, v in _old_namedtuple_kwdefaults.items(): + kwargs[k] = kwargs.get(k, v) + cls = _old_namedtuple(*args, **kwargs) + return _hack_namedtuple(cls) + + # replace namedtuple with the new one + collections.namedtuple.__globals__[ + "_old_namedtuple_kwdefaults" + ] = _old_namedtuple_kwdefaults + collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.__code__ = namedtuple.__code__ + collections.namedtuple.__hijack = 1 + + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, + # so only hack those in __main__ module + for n, o in sys.modules["__main__"].__dict__.items(): + if ( + type(o) is type + and o.__base__ is tuple + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__ + ): + _hack_namedtuple(o) # hack inplace + + _hijack_namedtuple() class PickleSerializer(FramedSerializer): @@ -436,7 +445,7 @@ def loads(self, obj, encoding="bytes"): return pickle.loads(obj, encoding=encoding) -class CloudPickleSerializer(PickleSerializer): +class CloudPickleSerializer(FramedSerializer): def dumps(self, obj): try: return cloudpickle.dumps(obj, pickle_protocol) @@ -451,6 +460,15 @@ def dumps(self, obj): print_exec(sys.stderr) raise pickle.PicklingError(msg) + def loads(self, obj, encoding="bytes"): + return cloudpickle.loads(obj, encoding=encoding) + + +if sys.version_info < (3, 8): + CPickleSerializer = PickleSerializer +else: + CPickleSerializer = CloudPickleSerializer + class MarshalSerializer(FramedSerializer): @@ -459,7 +477,7 @@ class MarshalSerializer(FramedSerializer): http://docs.python.org/2/library/marshal.html - This serializer is faster than PickleSerializer but supports fewer datatypes. + This serializer is faster than CloudPickleSerializer but supports fewer datatypes. 
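The MarshalSerializer docstring above still describes the trade-off accurately after the rename: it is faster than the default pickle-based serializer but handles fewer types. A short sketch of swapping it in at context creation, assuming this patch is applied (the app name is arbitrary):

    from pyspark import SparkContext
    from pyspark.serializers import MarshalSerializer

    # Faster framed serializer, but only for marshal-supported types (ints, strings, ...).
    sc = SparkContext("local", "marshal-demo", serializer=MarshalSerializer())
    print(sc.parallelize(range(5)).map(lambda x: x * 2).collect())
    sc.stop()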
""" def dumps(self, obj): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 9dbe314d29b2..bd455667f36f 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,7 +28,7 @@ import heapq from pyspark.serializers import ( BatchedSerializer, - PickleSerializer, + CPickleSerializer, FlattenedValuesSerializer, CompressedSerializer, AutoBatchedSerializer, @@ -140,8 +140,8 @@ def items(self): def _compressed_serializer(self, serializer=None): - # always use PickleSerializer to simplify implementation - ser = PickleSerializer() + # always use CPickleSerializer to simplify implementation + ser = CPickleSerializer() return AutoBatchedSerializer(CompressedSerializer(ser)) @@ -609,7 +609,7 @@ def _open_file(self): os.makedirs(d) p = os.path.join(d, str(id(self))) self._file = open(p, "w+b", 65536) - self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + self._ser = BatchedSerializer(CompressedSerializer(CPickleSerializer()), 1024) os.unlink(p) def __del__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 337cad534faa..160e7c3841af 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -47,7 +47,7 @@ _load_from_socket, _local_iterator_from_socket, ) -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column @@ -121,7 +121,7 @@ def rdd(self) -> "RDD[Row]": """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.""" if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(CPickleSerializer())) return self._lazy_rdd @property # type: ignore[misc] @@ -592,7 +592,7 @@ def _repr_html_(self) -> Optional[str]: max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate(), # type: ignore[attr-defined] ) - rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) head = rows[0] row_data = rows[1:] has_more_data = len(row_data) > max_num_rows @@ -769,7 +769,7 @@ def collect(self) -> List[Row]: """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectToPython() - return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: """ @@ -792,7 +792,7 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.toPythonIterator(prefetchPartitions) - return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer())) + return _local_iterator_from_socket(sock_info, BatchedSerializer(CPickleSerializer())) def limit(self, num: int) -> "DataFrame": """Limits the result count to the number specified. 
@@ -837,7 +837,7 @@ def tail(self, num: int) -> List[Row]: """ with SCCallSiteSync(self._sc): sock_info = self._jdf.tailToPython(num) - return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) def foreach(self, f: Callable[[Row], None]) -> None: """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py index cc0db017c301..bc6202f85463 100644 --- a/python/pyspark/sql/pandas/utils.py +++ b/python/pyspark/sql/pandas/utils.py @@ -19,7 +19,7 @@ def require_minimum_pandas_version() -> None: """Raise ImportError if minimum version of Pandas is not installed""" # TODO(HyukjinKwon): Relocate and deduplicate the version specification. - minimum_pandas_version = "0.23.2" + minimum_pandas_version = "1.0.5" from distutils.version import LooseVersion diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 927554198743..f94b9c211504 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -291,7 +291,7 @@ def __init__( self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = None, - options: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = {}, ): from pyspark.sql.context import SQLContext @@ -305,10 +305,7 @@ def __init__( ): jsparkSession = self._jvm.SparkSession.getDefaultSession().get() else: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) - if options is not None: - for key, value in options.items(): - jsparkSession.sharedState().conf().set(key, value) + jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 53a098ce4985..74593d070004 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -1200,7 +1200,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt """ from pyspark.rdd import _wrap_function # type: ignore[attr-defined] - from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.taskcontext import TaskContext if callable(f): @@ -1268,7 +1268,7 @@ def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Itera func = func_with_open_process_close # type: ignore[assignment] - serializer = AutoBatchedSerializer(PickleSerializer()) + serializer = AutoBatchedSerializer(CPickleSerializer()) wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) jForeachWriter = self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( # type: ignore[attr-defined] wrapped_func, self._df._jdf.schema() diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index eb23b68ccf49..06771fac896b 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -289,18 +289,23 @@ def test_another_spark_session(self): if session2 is not None: session2.stop() - def test_create_spark_context_first_and_copy_options_to_sharedState(self): + def test_create_spark_context_with_initial_session_options(self): sc = None session = None try: conf = SparkConf().set("key1", "value1") sc = 
SparkContext("local[4]", "SessionBuilderTests", conf=conf) session = ( - SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate() + SparkSession.builder.config("spark.sql.codegen.comments", "true") + .enableHiveSupport() + .getOrCreate() ) self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1") - self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2") + self.assertEqual( + session._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments"), + "true", + ) self.assertEqual( session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"), "hive", diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py new file mode 100644 index 000000000000..27d945850940 --- /dev/null +++ b/python/pyspark/sql/tests/test_udf_profiler.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import unittest +import os +import sys +from io import StringIO + +from pyspark import SparkConf, SparkContext +from pyspark.sql import SparkSession +from pyspark.sql.functions import udf +from pyspark.profiler import UDFBasicProfiler + + +class UDFProfilerTests(unittest.TestCase): + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext("local[4]", class_name, conf=conf) + self.spark = SparkSession.builder._sparkContext(self.sc).getOrCreate() + + def tearDown(self): + self.spark.stop() + sys.path = self._old_sys_path + + def test_udf_profiler(self): + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(3, len(profilers)) + + old_stdout = sys.stdout + try: + sys.stdout = io = StringIO() + self.sc.show_profiles() + finally: + sys.stdout = old_stdout + + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + + for i, udf_name in enumerate(["add1", "add2", "add1"]): + id, profiler, _ = profilers[i] + with self.subTest(id=id, udf_name=udf_name): + stats = profiler.stats() + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue(udf_name in func_names) + + self.assertTrue(udf_name in io.getvalue()) + self.assertTrue("udf_%d.pstats" % id in os.listdir(d)) + + def test_custom_udf_profiler(self): + class TestCustomProfiler(UDFBasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.udf_profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(3, len(profilers)) + _, profiler, _ = profilers[0] + 
self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + @udf + def add1(x): + return x + 1 + + @udf + def add2(x): + return x + 2 + + df = self.spark.range(10) + df.select(add1("id"), add2("id"), add1("id")).collect() + + +if __name__ == "__main__": + from pyspark.sql.tests.test_udf_profiler import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 886451a7cccb..0b47f8796acf 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -18,12 +18,14 @@ User-defined function related classes and functions """ import functools +import inspect import sys from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union from py4j.java_gateway import JavaObject # type: ignore[import] from pyspark import SparkContext +from pyspark.profiler import Profiler from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType # type: ignore[attr-defined] from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import ( @@ -209,16 +211,16 @@ def _judf(self) -> JavaObject: # This is unlikely, doesn't affect correctness, # and should have a minimal performance impact. if self._judf_placeholder is None: - self._judf_placeholder = self._create_judf() + self._judf_placeholder = self._create_judf(self.func) return self._judf_placeholder - def _create_judf(self) -> JavaObject: + def _create_judf(self, func: Callable[..., Any]) -> JavaObject: from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( # type: ignore[attr-defined] self._name, wrapped_func, jdt, self.evalType, self.deterministic @@ -226,9 +228,29 @@ def _create_judf(self) -> JavaObject: return judf def __call__(self, *cols: "ColumnOrName") -> Column: - judf = self._judf sc = SparkContext._active_spark_context # type: ignore[attr-defined] - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + + profiler: Optional[Profiler] = None + if sc.profiler_collector: + f = self.func + profiler = sc.profiler_collector.new_udf_profiler(sc) + + @functools.wraps(f) + def func(*args: Any, **kwargs: Any) -> Any: + assert profiler is not None + return profiler.profile(f, *args, **kwargs) + + func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] + + judf = self._create_judf(func) + else: + judf = self._judf + + jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column)) + if profiler is not None: + id = jPythonUDF.expr().resultId().id() + sc.profiler_collector.add_profiler(id, profiler) + return Column(jPythonUDF) # This function is for improving the online help system in the interactive interpreter. # For example, the built-in help / pydoc.help. 
It wraps the UDF with the docstring and diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index a2c9ce90e943..51c1149080bd 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -29,7 +29,7 @@ from pyspark.serializers import ( CloudPickleSerializer, BatchedSerializer, - PickleSerializer, + CPickleSerializer, MarshalSerializer, UTF8Deserializer, NoOpSerializer, @@ -446,7 +446,7 @@ def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) b = self.sc.parallelize(range(100, 105)) self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + a = a._reserialize(BatchedSerializer(CPickleSerializer(), 2)) b = b._reserialize(MarshalSerializer()) self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) # regression test for SPARK-4841 diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py index 3a9e14dd16a3..019f5279bc57 100644 --- a/python/pyspark/tests/test_serializers.py +++ b/python/pyspark/tests/test_serializers.py @@ -29,7 +29,7 @@ PairDeserializer, FlattenedValuesSerializer, CartesianDeserializer, - PickleSerializer, + CPickleSerializer, UTF8Deserializer, MarshalSerializer, ) @@ -46,15 +46,13 @@ class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): from collections import namedtuple - from pickle import dumps, loads + from pyspark.cloudpickle import dumps, loads P = namedtuple("P", "x y") p1 = P(1, 3) p2 = loads(dumps(p1, 2)) self.assertEqual(p1, p2) - from pyspark.cloudpickle import dumps - P2 = loads(dumps(P)) p3 = P2(1, 3) self.assertEqual(p1, p3) @@ -132,7 +130,7 @@ def foo(): ser.dumps(foo) def test_compressed_serializer(self): - ser = CompressedSerializer(PickleSerializer()) + ser = CompressedSerializer(CPickleSerializer()) from io import BytesIO as StringIO io = StringIO() @@ -147,15 +145,15 @@ def test_compressed_serializer(self): def test_hash_serializer(self): hash(NoOpSerializer()) hash(UTF8Deserializer()) - hash(PickleSerializer()) + hash(CPickleSerializer()) hash(MarshalSerializer()) hash(AutoSerializer()) - hash(BatchedSerializer(PickleSerializer())) + hash(BatchedSerializer(CPickleSerializer())) hash(AutoBatchedSerializer(MarshalSerializer())) hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CompressedSerializer(PickleSerializer())) - hash(FlattenedValuesSerializer(PickleSerializer())) + hash(CompressedSerializer(CPickleSerializer())) + hash(FlattenedValuesSerializer(CPickleSerializer())) @unittest.skipIf(not have_scipy, "SciPy not installed") diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py index 805a47dd1e3a..5d69b67fc3e0 100644 --- a/python/pyspark/tests/test_shuffle.py +++ b/python/pyspark/tests/test_shuffle.py @@ -19,7 +19,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext +from pyspark import shuffle, CPickleSerializer, SparkConf, SparkContext from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter @@ -80,7 +80,7 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) - ser = PickleSerializer() + ser = CPickleSerializer() l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) for k, vs in l: self.assertEqual(k, len(vs)) diff --git a/python/pyspark/worker.py 
b/python/pyspark/worker.py index c2200b20fe25..1935e27d6636 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,7 +50,7 @@ read_int, SpecialLengths, UTF8Deserializer, - PickleSerializer, + CPickleSerializer, BatchedSerializer, ) from pyspark.sql.pandas.serializers import ( @@ -63,7 +63,7 @@ from pyspark.util import fail_on_stopiteration, try_simplify_traceback # type: ignore from pyspark import shuffle -pickleSer = PickleSerializer() +pickleSer = CPickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -367,7 +367,7 @@ def read_udfs(pickleSer, infile, eval_type): timezone, safecheck, assign_cols_by_name, df_for_struct ) else: - ser = BatchedSerializer(PickleSerializer(), 100) + ser = BatchedSerializer(CPickleSerializer(), 100) num_udfs = read_int(infile) diff --git a/python/setup.py b/python/setup.py index 4507a2686e2c..174995d4aec4 100755 --- a/python/setup.py +++ b/python/setup.py @@ -111,7 +111,7 @@ def _supports_symlinks(): # For Arrow, you should also check ./pom.xml and ensure there are no breaking changes in the # binary format protocol with the Java version, see ARROW_HOME/format/* for specifications. # Also don't forget to update python/docs/source/getting_started/install.rst. -_minimum_pandas_version = "0.23.2" +_minimum_pandas_version = "1.0.5" _minimum_pyarrow_version = "1.0.0" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala index 9e3d23c0063c..192b5993efe0 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -24,12 +24,21 @@ import io.fabric8.kubernetes.client.KubernetesClient import scala.collection.JavaConverters._ import org.apache.spark.SparkConf +import org.apache.spark.annotation.{DeveloperApi, Since, Stable} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.util.{ThreadUtils, Utils} -private[spark] class ExecutorPodsPollingSnapshotSource( +/** + * :: DeveloperApi :: + * + * A class used for polling K8s executor pods by ExternalClusterManagers. 
+ * @since 3.1.3 + */ +@Stable +@DeveloperApi +class ExecutorPodsPollingSnapshotSource( conf: SparkConf, kubernetesClient: KubernetesClient, snapshotsStore: ExecutorPodsSnapshotsStore, @@ -39,6 +48,7 @@ private[spark] class ExecutorPodsPollingSnapshotSource( private var pollingFuture: Future[_] = _ + @Since("3.1.3") def start(applicationId: String): Unit = { require(pollingFuture == null, "Cannot start polling more than once.") logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") @@ -46,6 +56,7 @@ private[spark] class ExecutorPodsPollingSnapshotSource( new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) } + @Since("3.1.3") def stop(): Unit = { if (pollingFuture != null) { pollingFuture.cancel(true) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala index 762878cbacac..06d942eb5b36 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -22,16 +22,27 @@ import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClient, Watcher, WatcherException} import io.fabric8.kubernetes.client.Watcher.Action +import org.apache.spark.annotation.{DeveloperApi, Since, Stable} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[spark] class ExecutorPodsWatchSnapshotSource( +/** + * :: DeveloperApi :: + * + * A class used for watching K8s executor pods by ExternalClusterManagers. 
+ * + * @since 3.1.3 + */ +@Stable +@DeveloperApi +class ExecutorPodsWatchSnapshotSource( snapshotsStore: ExecutorPodsSnapshotsStore, kubernetesClient: KubernetesClient) extends Logging { private var watchConnection: Closeable = _ + @Since("3.1.3") def start(applicationId: String): Unit = { require(watchConnection == null, "Cannot start the watcher twice.") logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + @@ -42,6 +53,7 @@ private[spark] class ExecutorPodsWatchSnapshotSource( .watch(new ExecutorPodsWatcher()) } + @Since("3.1.3") def stop(): Unit = { if (watchConnection != null) { Utils.tryLogNonFatalError { diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 index f9ab64e94a54..96dd6c996b34 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 @@ -51,6 +51,7 @@ COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark +ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/ WORKDIR /opt/spark/work-dir RUN chmod g+w /opt/spark/work-dir diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 8b14b7ecb1e2..8d10985b4f26 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -36,7 +36,7 @@ INCLUDE_TAGS="k8s" EXCLUDE_TAGS= JAVA_VERSION="8" BUILD_DEPENDENCIES_MVN_FLAG="-am" -HADOOP_PROFILE="hadoop-3" +HADOOP_PROFILE="hadoop-3.2" MVN="$TEST_ROOT_DIR/build/mvn" SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/null\ diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 82da4f583303..d281e38ebf05 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -192,7 +192,7 @@ - hadoop-3 + hadoop-3.2 true diff --git a/resource-managers/mesos/src/test/resources/log4j.properties b/resource-managers/mesos/src/test/resources/log4j.properties new file mode 100644 index 000000000000..9ec68901eedd --- /dev/null +++ b/resource-managers/mesos/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.sparkproject.jetty=WARN diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 2d7b7f34d79a..287f636b9e9c 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -70,7 +70,7 @@ - hadoop-3 + hadoop-3.2 true diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7787e2fc9200..e6136fc54fd0 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -169,7 +169,6 @@ private[spark] class Client( def submitApplication(): ApplicationId = { ResourceRequestHelper.validateResources(sparkConf) - var appId: ApplicationId = null try { launcherBackend.connect() yarnClient.init(hadoopConf) @@ -181,7 +180,7 @@ private[spark] class Client( // Get a new application from our RM val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() - appId = newAppResponse.getApplicationId() + this.appId = newAppResponse.getApplicationId() // The app staging dir based on the STAGING_DIR configuration if configured // otherwise based on the users home directory. @@ -207,8 +206,7 @@ private[spark] class Client( yarnClient.submitApplication(appContext) launcherBackend.setAppId(appId.toString) reportLauncherState(SparkAppHandle.State.SUBMITTED) - - appId + this.appId } catch { case e: Throwable => if (stagingDirPath != null) { @@ -915,7 +913,6 @@ private[spark] class Client( private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) : ContainerLaunchContext = { logInfo("Setting up container launch context for our AM") - val appId = newAppResponse.getApplicationId val pySparkArchives = if (sparkConf.get(IS_PYTHON_APP)) { findPySparkArchives() @@ -971,7 +968,7 @@ private[spark] class Client( if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts) - .map(Utils.substituteAppId(_, appId.toString)) + .map(Utils.substituteAppId(_, this.appId.toString)) .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), @@ -996,7 +993,7 @@ private[spark] class Client( throw new SparkException(msg) } javaOpts ++= Utils.splitCommandString(opts) - .map(Utils.substituteAppId(_, appId.toString)) + .map(Utils.substituteAppId(_, this.appId.toString)) .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => @@ -1269,7 +1266,7 @@ private[spark] class Client( * throw an appropriate SparkException. 
*/ def run(): Unit = { - this.appId = submitApplication() + submitApplication() if (!launcherBackend.isConnected() && fireAndForget) { val report = getApplicationReport(appId) val state = report.getYarnApplicationState diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 073fa283100b..a1e7b4bb1613 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -122,7 +122,7 @@ statement : query #statementDefault | ctes? dmlStatementNoWith #dmlStatement | USE multipartIdentifier #use - | USE NAMESPACE multipartIdentifier #useNamespace + | USE namespace multipartIdentifier #useNamespace | SET CATALOG (identifier | STRING) #setCatalog | CREATE namespace (IF NOT EXISTS)? multipartIdentifier (commentSpec | @@ -134,7 +134,7 @@ statement SET locationSpec #setNamespaceLocation | DROP namespace (IF EXISTS)? multipartIdentifier (RESTRICT | CASCADE)? #dropNamespace - | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? + | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=STRING)? #showNamespaces | createTableHeader ('(' colTypeList ')')? tableProvider? createTableClauses @@ -228,7 +228,7 @@ statement | SHOW identifier? FUNCTIONS (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable - | SHOW CURRENT NAMESPACE #showCurrentNamespace + | SHOW CURRENT namespace #showCurrentNamespace | SHOW CATALOGS (LIKE? pattern=STRING)? #showCatalogs | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction | (DESC | DESCRIBE) namespace EXTENDED? 
@@ -382,6 +382,12 @@ namespace | SCHEMA ; +namespaces + : NAMESPACES + | DATABASES + | SCHEMAS + ; + describeFuncName : qualifiedName | STRING @@ -1230,6 +1236,7 @@ ansiNonReserved | ROW | ROWS | SCHEMA + | SCHEMAS | SECOND | SEMI | SEPARATED @@ -1501,6 +1508,7 @@ nonReserved | ROW | ROWS | SCHEMA + | SCHEMAS | SECOND | SELECT | SEPARATED @@ -1628,7 +1636,7 @@ CURRENT_USER: 'CURRENT_USER'; DAY: 'DAY'; DATA: 'DATA'; DATABASE: 'DATABASE'; -DATABASES: 'DATABASES' | 'SCHEMAS'; +DATABASES: 'DATABASES'; DBPROPERTIES: 'DBPROPERTIES'; DEFINED: 'DEFINED'; DELETE: 'DELETE'; @@ -1774,6 +1782,7 @@ ROW: 'ROW'; ROWS: 'ROWS'; SECOND: 'SECOND'; SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; SELECT: 'SELECT'; SEMI: 'SEMI'; SEPARATED: 'SEPARATED'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index 34f07b12b366..359bc0017bcb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -68,6 +68,16 @@ public Table loadTable(Identifier ident) throws NoSuchTableException { return asTableCatalog().loadTable(ident); } + @Override + public Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + return asTableCatalog().loadTable(ident, timestamp); + } + + @Override + public Table loadTable(Identifier ident, String version) throws NoSuchTableException { + return asTableCatalog().loadTable(ident, version); + } + @Override public void invalidateTable(Identifier ident) { asTableCatalog().invalidateTable(ident); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java new file mode 100644 index 000000000000..777693938c4e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in for input partitions whose records are clustered on the same set of partition keys + * (provided via {@link SupportsReportPartitioning}, see below). Data sources can opt-in to + * implement this interface for the partitions they report to Spark, which will use the + * information to avoid data shuffling in certain scenarios, such as join, aggregate, etc. Note + * that Spark requires ALL input partitions to implement this interface, otherwise it can't take + * advantage of it. + *
<p>
+ This interface should be used in combination with {@link SupportsReportPartitioning}, which + allows data sources to report a distribution and ordering spec to Spark. In particular, Spark + expects data sources to report + {@link org.apache.spark.sql.connector.distributions.ClusteredDistribution} whenever its input + partitions implement this interface. Spark derives the partition key spec (e.g., column names, + transforms) from the distribution, and partition values from the input partitions. + *
<p>
+ * It is implementor's responsibility to ensure that when an input partition implements this + * interface, its records all have the same value for the partition keys. Spark doesn't check + * this property. + * + * @see org.apache.spark.sql.connector.read.SupportsReportPartitioning + * @see org.apache.spark.sql.connector.read.partitioning.Partitioning + */ +public interface HasPartitionKey extends InputPartition { + /** + * Returns the value of the partition key(s) associated to this partition. An input partition + * implementing this interface needs to ensure that all its records have the same value for the + * partition keys. Note that the value is after partition transform has been applied, if there + * is any. + */ + InternalRow partitionKey(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index debc13b953ee..267c2cc8868a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -75,7 +75,7 @@ import org.apache.spark.sql.types._ object AnsiTypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = WidenSetOperationTypes :: - CombinedTypeCoercionRule( + new AnsiCombinedTypeCoercionRule( InConversion :: PromoteStringLiterals :: DecimalPrecision :: @@ -304,4 +304,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { s.copy(left = newLeft, right = newRight) } } + + // This is for generating a new rule id, so that we can run both default and Ansi + // type coercion rules against one logical plan. + class AnsiCombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends + CombinedTypeCoercionRule(rules) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5bf37a2944cb..491d52588f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -47,6 +48,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil + val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError") + protected def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } @@ -165,14 +168,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } } - val exprs = operator match { - // `groupingExpressions` may rely on `aggregateExpressions`, due to the GROUP BY alias - // feature. We should check errors in `aggregateExpressions` first. 
- case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions - case _ => operator.expressions - } - - exprs.foreach(_.foreachUp { + getAllExpressions(operator).foreach(_.foreachUp { case a: Attribute if !a.resolved => val missingCol = a.sql val candidates = operator.inputSet.toSeq.map(_.qualifiedName) @@ -189,8 +185,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => + e.setTagValue(DATA_TYPE_MISMATCH_ERROR, true) e.failAnalysis( - s"cannot resolve '${e.sql}' due to data type mismatch: $message") + s"cannot resolve '${e.sql}' due to data type mismatch: $message" + + extraHintForAnsiTypeCoercionExpression(operator)) } case c: Cast if !c.resolved => @@ -424,27 +422,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { |the ${ordinalNumber(ti + 1)} table has ${child.output.length} columns """.stripMargin.replace("\n", " ").trim()) } - val isUnion = operator.isInstanceOf[Union] - val dataTypesAreCompatibleFn = if (isUnion) { - (dt1: DataType, dt2: DataType) => - !DataType.equalsStructurally(dt1, dt2, true) - } else { - // SPARK-18058: we shall not care about the nullability of columns - (dt1: DataType, dt2: DataType) => - TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty - } + val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(operator) // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns if (dataTypesAreCompatibleFn(dt1, dt2)) { - failAnalysis( + val errorMessage = s""" |${operator.nodeName} can only be performed on tables with the compatible |column types. The ${ordinalNumber(ci)} column of the |${ordinalNumber(ti + 1)} table is ${dt1.catalogString} type which is not |compatible with ${dt2.catalogString} at same column of first table - """.stripMargin.replace("\n", " ").trim()) + """.stripMargin.replace("\n", " ").trim() + failAnalysis(errorMessage + extraHintForAnsiTypeCoercionPlan(operator)) } } } @@ -593,6 +584,86 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { plan.setAnalyzed() } + private def getAllExpressions(plan: LogicalPlan): Seq[Expression] = { + plan match { + // `groupingExpressions` may rely on `aggregateExpressions`, due to the GROUP BY alias + // feature. We should check errors in `aggregateExpressions` first. + case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions + case _ => plan.expressions + } + } + + private def getDataTypesAreCompatibleFn(plan: LogicalPlan): (DataType, DataType) => Boolean = { + val isUnion = plan.isInstanceOf[Union] + if (isUnion) { + (dt1: DataType, dt2: DataType) => + !DataType.equalsStructurally(dt1, dt2, true) + } else { + // SPARK-18058: we shall not care about the nullability of columns + (dt1: DataType, dt2: DataType) => + TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty + } + } + + private def getDefaultTypeCoercionPlan(plan: LogicalPlan): LogicalPlan = + TypeCoercion.typeCoercionRules.foldLeft(plan) { case (p, rule) => rule(p) } + + private def extraHintMessage(issueFixedIfAnsiOff: Boolean): String = { + if (issueFixedIfAnsiOff) { + "\nTo fix the error, you might need to add explicit type casts. If necessary set " + + s"${SQLConf.ANSI_ENABLED.key} to false to bypass this error." 
+ } else { + "" + } + } + + private def extraHintForAnsiTypeCoercionExpression(plan: LogicalPlan): String = { + if (!SQLConf.get.ansiEnabled) { + "" + } else { + val nonAnsiPlan = getDefaultTypeCoercionPlan(plan) + var issueFixedIfAnsiOff = true + getAllExpressions(nonAnsiPlan).foreach(_.foreachUp { + case e: Expression if e.getTagValue(DATA_TYPE_MISMATCH_ERROR).contains(true) && + e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(_) => + issueFixedIfAnsiOff = false + } + + case _ => + }) + extraHintMessage(issueFixedIfAnsiOff) + } + } + + private def extraHintForAnsiTypeCoercionPlan(plan: LogicalPlan): String = { + if (!SQLConf.get.ansiEnabled) { + "" + } else { + val nonAnsiPlan = getDefaultTypeCoercionPlan(plan) + var issueFixedIfAnsiOff = true + nonAnsiPlan match { + case _: Union | _: SetOperation if nonAnsiPlan.children.length > 1 => + def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) + + val ref = dataTypes(nonAnsiPlan.children.head) + val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(nonAnsiPlan) + nonAnsiPlan.children.tail.zipWithIndex.foreach { case (child, ti) => + // Check if the data types match. + dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => + if (dataTypesAreCompatibleFn(dt1, dt2)) { + issueFixedIfAnsiOff = false + } + } + } + + case _ => + } + extraHintMessage(issueFixedIfAnsiOff) + } + } + /** * Validates subquery expressions in the plan. Upon failure, returns an user facing error. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index efc1ab2cd0e1..83d7b932a8bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -37,17 +37,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case UnresolvedDBObjectName(CatalogAndIdentifier(catalog, identifier), _) => ResolvedDBObjectName(catalog, identifier.namespace :+ identifier.name()) - case c @ CreateTableStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => - CreateV2Table( - catalog.asTableCatalog, - tbl.asIdentifier, - c.tableSchema, - // convert the bucket spec and add it as a transform - c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c), - ignoreIfExists = c.ifNotExists) - case c @ CreateTableAsSelectStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _, _) => CreateTableAsSelect( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala index 2d73b99d2712..cbb6e8bb06a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Expression, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Expression, RuntimeReplaceable, SubqueryExpression, Unevaluable} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.types.TimestampType @@ -40,10 +40,17 @@ object TimeTravelSpec { if (!AnsiCast.canCast(ts.dataType, TimestampType)) { throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) } + val tsToEval = ts.transform { + case r: RuntimeReplaceable => r.child + case _: Unevaluable => + throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) + case e if !e.deterministic => + throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) + } val tz = Some(conf.sessionLocalTimeZone) // Set `ansiEnabled` to false, so that it can return null for invalid input and we can provide // better error message. - val value = Cast(ts, TimestampType, tz, ansiEnabled = false).eval() + val value = Cast(tsToEval, TimestampType, tz, ansiEnabled = false).eval() if (value == null) { throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 506667461ec0..82fba937617d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -170,7 +170,7 @@ abstract class TypeCoercionBase { * Type coercion rule that combines multiple type coercion rules and applies them in a single tree * traversal. */ - case class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule { + class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule { override def transform: PartialFunction[Expression, Expression] = { val transforms = rules.map(_.transform) Function.unlift { e: Expression => @@ -795,7 +795,7 @@ object TypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = WidenSetOperationTypes :: - CombinedTypeCoercionRule( + new CombinedTypeCoercionRule( InConversion :: PromoteStrings :: DecimalPrecision :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 221f5ae73673..3b501d686cfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -292,6 +292,16 @@ abstract class Expression extends TreeNode[Expression] { override def simpleStringWithNodeId(): String = { throw QueryExecutionErrors.simpleStringWithNodeIdUnsupportedError(nodeName) } + + protected def typeSuffix = + if (resolved) { + dataType match { + case LongType => "L" + case _ => "" + } + } else { + "" + } } @@ -387,6 +397,7 @@ trait NonSQLExpression extends Expression { transform { case a: Attribute => new PrettyAttribute(a) case a: Alias => PrettyAttribute(a.sql, a.dataType) + case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children) }.toString } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 80e235286970..6b9017a01db3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -62,7 +62,7 @@ case class PythonUDF( override lazy val deterministic: Boolean = 
udfDeterministic && children.forall(_.deterministic) - override def toString: String = s"$name(${children.mkString(", ")})" + override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix" final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF) @@ -80,3 +80,21 @@ case class PythonUDF( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF = copy(children = newChildren) } + +/** + * A place holder used when printing expressions without debugging information such as the + * result id. + */ +case class PrettyPythonUDF( + name: String, + dataType: DataType, + children: Seq[Expression]) + extends Expression with Unevaluable with NonSQLExpression { + + override def toString: String = s"$name(${children.mkString(", ")})" + + override def nullable: Boolean = true + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PrettyPythonUDF = copy(children = newChildren) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 71f193e51074..5cc81244c8d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -101,16 +101,6 @@ trait NamedExpression extends Expression { /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression - - protected def typeSuffix = - if (resolved) { - dataType match { - case LongType => "L" - case _ => "" - } - } else { - "" - } } abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a78f08ac8f4c..2c32ba7db551 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3152,10 +3152,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Create a [[ShowNamespaces]] command. */ override def visitShowNamespaces(ctx: ShowNamespacesContext): LogicalPlan = withOrigin(ctx) { - if (ctx.DATABASES != null && ctx.multipartIdentifier != null) { - throw QueryParsingErrors.fromOrInNotAllowedInShowDatabasesError(ctx) - } - val multiPart = Option(ctx.multipartIdentifier).map(visitMultipartIdentifier) ShowNamespaces( UnresolvedNamespace(multiPart.getOrElse(Seq.empty[String])), @@ -3414,7 +3410,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a table, returning a [[CreateTableStatement]] logical plan. + * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelectStatement]] logical plan. * * Expected format: * {{{ @@ -3481,9 +3477,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case _ => // Note: table schema includes both the table columns list and the partition columns // with data type. 
+ val tableSpec = TableSpec(bucketSpec, properties, provider, options, location, comment, + serdeInfo, external) val schema = StructType(columns ++ partCols) - CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider, - options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists) + CreateTable( + UnresolvedDBObjectName(table, isNamespace = false), + schema, partitioning, tableSpec, ignoreIfExists = ifNotExists) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 7c31a0091811..4aa7bf1c4f9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -41,7 +41,8 @@ abstract class LogicalPlan def metadataOutput: Seq[Attribute] = children.flatMap(_.metadataOutput) /** Returns true if this subtree has data from a streaming data source. */ - def isStreaming: Boolean = children.exists(_.isStreaming) + def isStreaming: Boolean = _isStreaming + private[this] lazy val _isStreaming = children.exists(_.isStreaming) override def verboseStringWithSuffix(maxFields: Int): String = { super.verboseString(maxFields) + statsCache.map(", " + _.toString).getOrElse("") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index ccc4e190ba51..a6ed304e7155 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -123,25 +123,6 @@ object SerdeInfo { } } -/** - * A CREATE TABLE command, as parsed from SQL. - * - * This is a metadata-only command and is not used to write data to the created table. - */ -case class CreateTableStatement( - tableName: Seq[String], - tableSchema: StructType, - partitioning: Seq[Transform], - bucketSpec: Option[BucketSpec], - properties: Map[String, String], - provider: Option[String], - options: Map[String, String], - location: Option[String], - comment: Option[String], - serde: Option[SerdeInfo], - external: Boolean, - ifNotExists: Boolean) extends LeafParsedStatement - /** * A CREATE TABLE AS SELECT command, as parsed from SQL. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala index 5ae8d69f4b99..091955b6b02c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala @@ -50,13 +50,19 @@ object UnionEstimation { case TimestampType => (a: Any, b: Any) => TimestampType.ordering.lt(a.asInstanceOf[TimestampType.InternalType], b.asInstanceOf[TimestampType.InternalType]) + case TimestampNTZType => (a: Any, b: Any) => + TimestampNTZType.ordering.lt(a.asInstanceOf[TimestampNTZType.InternalType], + b.asInstanceOf[TimestampNTZType.InternalType]) + case i: AnsiIntervalType => (a: Any, b: Any) => + i.ordering.lt(a.asInstanceOf[i.InternalType], b.asInstanceOf[i.InternalType]) case _ => throw new IllegalStateException(s"Unsupported data type: ${dt.catalogString}") } private def isTypeSupported(dt: DataType): Boolean = dt match { case ByteType | IntegerType | ShortType | FloatType | LongType | - DoubleType | DateType | _: DecimalType | TimestampType => true + DoubleType | DateType | _: DecimalType | TimestampType | TimestampNTZType | + _: AnsiIntervalType => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 4ed5d87aaf10..d39e28865da8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException} +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.FunctionResource import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} @@ -193,13 +194,24 @@ trait V2CreateTablePlan extends LogicalPlan { /** * Create a new table with a v2 catalog. 
*/ -case class CreateV2Table( - catalog: TableCatalog, - tableName: Identifier, +case class CreateTable( + name: LogicalPlan, tableSchema: StructType, partitioning: Seq[Transform], - properties: Map[String, String], - ignoreIfExists: Boolean) extends LeafCommand with V2CreateTablePlan { + tableSpec: TableSpec, + ignoreIfExists: Boolean) extends UnaryCommand with V2CreateTablePlan { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + + override def child: LogicalPlan = name + + override def tableName: Identifier = { + assert(child.resolved) + child.asInstanceOf[ResolvedDBObjectName].nameParts.asIdentifier + } + + override protected def withNewChildInternal(newChild: LogicalPlan): V2CreateTablePlan = + copy(name = newChild) + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { this.copy(partitioning = rewritten) } @@ -1090,3 +1102,13 @@ case class DropIndex( override protected def withNewChildInternal(newChild: LogicalPlan): DropIndex = copy(table = newChild) } + +case class TableSpec( + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: Option[String], + options: Map[String, String], + location: Option[String], + comment: Option[String], + serde: Option[SerdeInfo], + external: Boolean) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 5ec303d97fbd..4face494621b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -76,6 +76,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" :: + "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" :: "org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" :: "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: "org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f3f6744720f2..3d62cf2b8342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.TableSpec import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.RuleId import org.apache.spark.sql.catalyst.rules.RuleIdCollection @@ -819,6 +820,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre redactMapString(map.asCaseSensitiveMap().asScala, maxFields) case map: Map[_, _] => redactMapString(map, maxFields) + case t: TableSpec => + t.copy(properties = Utils.redact(t.properties).toMap, + options = Utils.redact(t.options).toMap) :: 
Nil case table: CatalogTable => table.storage.serde match { case Some(serde) => table.identifier :: serde :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 33fe48d44dad..e26f397bb0b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -139,6 +139,7 @@ package object util extends Logging { PrettyAttribute(r.mkString(r.exprsReplaced.map(toPrettySQL)), r.dataType) case c: CastBase if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) => PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType) + case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children) } def quoteIdentifier(name: String): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 7c11463fc7c5..fabc73f9cf69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -23,7 +23,7 @@ import java.util.Collections import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, TimeTravelSpec} -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelectStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType} @@ -305,11 +305,6 @@ private[sql] object CatalogV2Util { catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) } - def convertTableProperties(c: CreateTableStatement): Map[String, String] = { - convertTableProperties( - c.properties, c.options, c.serde, c.location, c.comment, c.provider, c.external) - } - def convertTableProperties(c: CreateTableAsSelectStatement): Map[String, String] = { convertTableProperties( c.properties, c.options, c.serde, c.location, c.comment, c.provider, c.external) @@ -323,7 +318,7 @@ private[sql] object CatalogV2Util { convertTableProperties(r.properties, r.options, r.serde, r.location, r.comment, r.provider) } - private def convertTableProperties( + def convertTableProperties( properties: Map[String, String], options: Map[String, String], serdeInfo: Option[SerdeInfo], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 46dd489ea470..4ac6c3f9b134 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -248,10 +248,6 @@ object QueryParsingErrors { new ParseException("Either PROPERTIES or DBPROPERTIES is allowed.", ctx) } - def fromOrInNotAllowedInShowDatabasesError(ctx: ShowNamespacesContext): Throwable = { - new ParseException(s"FROM/IN 
operator is not allowed in SHOW DATABASES", ctx) - } - def cannotCleanReservedTablePropertyError( property: String, ctx: ParserRuleContext, msg: String): Throwable = { new ParseException(s"$property is a reserved table property, $msg.", ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 96262f5afbcd..d30bcd5af5da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -502,7 +502,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { } // stream-stream inner join doesn't emit late rows, whereas outer joins could - Seq((Inner, false), (LeftOuter, true), (RightOuter, true)).map { + Seq((Inner, false), (LeftOuter, true), (RightOuter, true)).foreach { case (joinType, expectFailure) => assertPassOnGlobalWatermarkLimit( s"single $joinType join in Append mode", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 70227bb0554f..f4ab8076938d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -717,8 +717,8 @@ class DDLParserSuite extends AnalysisTest { val parsedPlan = parsePlan(sqlStatement) val newTableToken = sqlStatement.split(" ")(0).trim.toUpperCase(Locale.ROOT) parsedPlan match { - case create: CreateTableStatement if newTableToken == "CREATE" => - assert(create.ifNotExists == expectedIfNotExists) + case create: CreateTable if newTableToken == "CREATE" => + assert(create.ignoreIfExists == expectedIfNotExists) case ctas: CreateTableAsSelectStatement if newTableToken == "CREATE" => assert(ctas.ifNotExists == expectedIfNotExists) case replace: ReplaceTableStatement if newTableToken == "REPLACE" => @@ -2285,19 +2285,19 @@ class DDLParserSuite extends AnalysisTest { private object TableSpec { def apply(plan: LogicalPlan): TableSpec = { plan match { - case create: CreateTableStatement => + case create: CreateTable => TableSpec( - create.tableName, + create.name.asInstanceOf[UnresolvedDBObjectName].nameParts, Some(create.tableSchema), create.partitioning, - create.bucketSpec, - create.properties, - create.provider, - create.options, - create.location, - create.comment, - create.serde, - create.external) + create.tableSpec.bucketSpec, + create.tableSpec.properties, + create.tableSpec.provider, + create.tableSpec.options, + create.tableSpec.location, + create.tableSpec.comment, + create.tableSpec.serde, + create.tableSpec.external) case replace: ReplaceTableStatement => TableSpec( replace.tableName, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala index 12b7de694bc7..e7041a71363c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala @@ -57,6 +57,9 @@ class UnionEstimationSuite extends StatsEstimationTestBase { val attrDecimal = 
AttributeReference("cdecimal", DecimalType(5, 4))() val attrDate = AttributeReference("cdate", DateType)() val attrTimestamp = AttributeReference("ctimestamp", TimestampType)() + val attrTimestampNTZ = AttributeReference("ctimestamp_ntz", TimestampNTZType)() + val attrYMInterval = AttributeReference("cyminterval", YearMonthIntervalType())() + val attrDTInterval = AttributeReference("cdtinterval", DayTimeIntervalType())() val s1 = 1.toShort val s2 = 4.toShort @@ -84,7 +87,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase { attrFloat -> ColumnStat(min = Some(1.1f), max = Some(4.1f)), attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.5))), attrDate -> ColumnStat(min = Some(1), max = Some(4)), - attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)))) + attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)), + attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(4L)), + attrYMInterval -> ColumnStat(min = Some(2), max = Some(5)), + attrDTInterval -> ColumnStat(min = Some(2L), max = Some(5L)))) val s3 = 2.toShort val s4 = 6.toShort @@ -118,7 +124,16 @@ class UnionEstimationSuite extends StatsEstimationTestBase { AttributeReference("cdate1", DateType)() -> ColumnStat(min = Some(3), max = Some(6)), AttributeReference("ctimestamp1", TimestampType)() -> ColumnStat( min = Some(3L), - max = Some(6L)))) + max = Some(6L)), + AttributeReference("ctimestamp_ntz1", TimestampNTZType)() -> ColumnStat( + min = Some(3L), + max = Some(6L)), + AttributeReference("cymtimestamp1", YearMonthIntervalType())() -> ColumnStat( + min = Some(4), + max = Some(8)), + AttributeReference("cdttimestamp1", DayTimeIntervalType())() -> ColumnStat( + min = Some(4L), + max = Some(8L)))) val child1 = StatsTestPlan( outputList = columnInfo.keys.toSeq.sortWith(_.exprId.id < _.exprId.id), @@ -147,7 +162,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase { attrFloat -> ColumnStat(min = Some(1.1f), max = Some(6.1f)), attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.9))), attrDate -> ColumnStat(min = Some(1), max = Some(6)), - attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L))))) + attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)), + attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(6L)), + attrYMInterval -> ColumnStat(min = Some(2), max = Some(8)), + attrDTInterval -> ColumnStat(min = Some(2L), max = Some(8L))))) assert(union.stats === expectedStats) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 3ebeacfa827f..fad6fe5fbe16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -173,8 +173,8 @@ class InMemoryTable( partitionSchema: StructType, from: Seq[Any], to: Seq[Any]): Boolean = { - val rows = dataMap.remove(from).getOrElse(new BufferedRows(from.mkString("/"))) - val newRows = new BufferedRows(to.mkString("/")) + val rows = dataMap.remove(from).getOrElse(new BufferedRows(from)) + val newRows = new BufferedRows(to) rows.rows.foreach { r => val newRow = new GenericInternalRow(r.numFields) for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType)) @@ -197,7 +197,7 @@ class InMemoryTable( protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { if (!dataMap.contains(key)) { - val emptyRows = 
new BufferedRows(key.toArray.mkString("/")) + val emptyRows = new BufferedRows(key) val rows = if (key.length == schema.length) { emptyRows.withRow(InternalRow.fromSeq(key)) } else emptyRows @@ -215,7 +215,7 @@ class InMemoryTable( val key = getKey(row) dataMap += dataMap.get(key) .map(key -> _.withRow(row)) - .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row)) + .getOrElse(key -> new BufferedRows(key).withRow(row)) addPartitionKey(key) }) this @@ -290,7 +290,7 @@ class InMemoryTable( case In(attrName, values) if attrName == partitioning.head.name => val matchingKeys = values.map(_.toString).toSet data = data.filter(partition => { - val key = partition.asInstanceOf[BufferedRows].key + val key = partition.asInstanceOf[BufferedRows].keyString matchingKeys.contains(key) }) @@ -508,8 +508,8 @@ object InMemoryTable { } } -class BufferedRows( - val key: String = "") extends WriterCommitMessage with InputPartition with Serializable { +class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage + with InputPartition with HasPartitionKey with Serializable { val rows = new mutable.ArrayBuffer[InternalRow]() def withRow(row: InternalRow): BufferedRows = { @@ -517,6 +517,12 @@ class BufferedRows( this } + def keyString(): String = key.toArray.mkString("/") + + override def partitionKey(): InternalRow = { + InternalRow.fromSeq(key) + } + def clear(): Unit = rows.clear() } @@ -538,7 +544,7 @@ private class BufferedRowsReader( private def addMetadata(row: InternalRow): InternalRow = { val metadataRow = new GenericInternalRow(metadataColumnNames.map { case "index" => index - case "_partition" => UTF8String.fromString(partition.key) + case "_partition" => UTF8String.fromString(partition.keyString) }.toArray) new JoinedRow(row, metadataRow) } diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt index 6578d5664cd3..c4cffd67b16a 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt @@ -2,251 +2,269 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz -SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 13405 13422 24 1.2 852.3 1.0X -SQL Json 10723 10788 92 1.5 681.7 1.3X -SQL Parquet Vectorized 164 217 50 95.9 10.4 81.8X -SQL Parquet MR 2349 2440 129 6.7 149.3 5.7X -SQL ORC Vectorized 312 346 23 50.4 19.8 43.0X -SQL ORC MR 1610 1659 69 9.8 102.4 8.3X +SQL CSV 9999 10058 83 1.6 635.7 1.0X +SQL Json 8857 8883 37 1.8 563.1 1.1X +SQL Parquet Vectorized 132 157 16 119.0 8.4 75.7X +SQL Parquet MR 1987 1997 14 7.9 126.3 5.0X +SQL ORC Vectorized 186 227 34 84.3 11.9 53.6X +SQL ORC MR 1559 1602 62 10.1 99.1 6.4X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative +------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized 110 117 9 143.0 7.0 1.0X +ParquetReader Vectorized -> Row 57 59 3 276.2 3.6 1.9X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SQL CSV 12897 12916 28 1.2 819.9 1.0X +SQL Json 9739 9770 44 1.6 619.2 1.3X +SQL Parquet Vectorized 226 237 14 69.7 14.3 57.2X +SQL Parquet MR 2124 2127 4 7.4 135.1 6.1X +SQL ORC Vectorized 213 250 39 73.9 13.5 60.6X +SQL ORC MR 1535 1548 19 10.2 97.6 8.4X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 187 209 20 84.3 11.9 1.0X -ParquetReader Vectorized -> Row 89 95 5 177.6 5.6 2.1X +ParquetReader Vectorized 259 269 15 60.6 16.5 1.0X +ParquetReader Vectorized -> Row 168 184 33 93.9 10.7 1.5X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 14214 14549 474 1.1 903.7 1.0X -SQL Json 11866 11934 95 1.3 754.4 1.2X -SQL Parquet Vectorized 294 342 53 53.6 18.7 48.4X -SQL Parquet MR 2929 3004 107 5.4 186.2 4.9X -SQL ORC Vectorized 312 328 15 50.4 19.8 45.5X -SQL ORC MR 2037 2097 84 7.7 129.5 7.0X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 12765 12774 13 1.2 811.6 1.0X +SQL Json 10144 10158 21 1.6 644.9 1.3X +SQL Parquet Vectorized 168 208 34 93.7 10.7 76.1X +SQL Parquet MR 2443 2458 21 6.4 155.3 5.2X +SQL ORC Vectorized 300 313 16 52.4 19.1 42.5X +SQL ORC MR 1736 1780 62 9.1 110.4 7.4X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 249 266 18 63.1 15.8 1.0X -ParquetReader Vectorized -> Row 192 247 36 82.1 12.2 1.3X +ParquetReader Vectorized 229 239 9 68.6 14.6 1.0X +ParquetReader Vectorized -> Row 224 265 26 70.2 14.3 1.0X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15502 15817 446 1.0 985.6 1.0X -SQL Json 12638 12646 11 1.2 803.5 1.2X -SQL Parquet Vectorized 193 256 44 81.7 12.2 80.5X -SQL Parquet MR 2943 2953 14 5.3 187.1 5.3X -SQL ORC Vectorized 324 370 34 48.5 20.6 47.8X -SQL ORC MR 2110 2163 75 7.5 134.1 7.3X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 14055 14060 6 1.1 893.6 1.0X +SQL Json 10692 10738 64 1.5 679.8 1.3X +SQL Parquet Vectorized 167 223 34 94.0 10.6 84.0X +SQL Parquet MR 2416 2482 94 6.5 153.6 5.8X +SQL ORC Vectorized 329 344 12 47.8 20.9 42.7X +SQL ORC MR 1773 1789 23 8.9 112.7 7.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 276 287 14 57.0 17.6 1.0X -ParquetReader Vectorized -> Row 309 320 9 50.9 19.6 0.9X +ParquetReader Vectorized 232 239 9 67.9 14.7 1.0X +ParquetReader Vectorized -> Row 262 295 23 60.1 16.6 0.9X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 20156 20694 761 0.8 1281.5 1.0X -SQL Json 15228 15380 214 1.0 968.2 1.3X -SQL Parquet Vectorized 325 346 20 48.4 20.7 62.0X -SQL Parquet MR 3144 3228 118 5.0 199.9 6.4X -SQL ORC Vectorized 516 526 7 30.5 32.8 39.0X -SQL ORC MR 2353 2367 19 6.7 149.6 8.6X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 18964 18975 17 0.8 1205.7 1.0X +SQL Json 13173 13189 23 1.2 837.5 1.4X +SQL Parquet Vectorized 278 290 11 56.6 17.7 68.2X +SQL Parquet MR 2565 2589 34 6.1 163.1 7.4X +SQL ORC Vectorized 432 481 48 36.4 27.5 43.9X +SQL ORC MR 2052 2061 12 7.7 130.5 9.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 372 396 24 42.3 23.6 1.0X -ParquetReader Vectorized -> Row 437 462 25 36.0 27.8 0.9X +ParquetReader Vectorized 296 321 29 53.2 18.8 1.0X +ParquetReader Vectorized -> Row 329 335 7 47.7 20.9 0.9X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17413 17599 263 0.9 1107.1 1.0X -SQL Json 14416 14453 53 1.1 916.5 1.2X -SQL Parquet Vectorized 181 225 35 86.8 11.5 96.1X -SQL 
Parquet MR 2940 2996 78 5.3 186.9 5.9X -SQL ORC Vectorized 470 494 29 33.5 29.9 37.1X -SQL ORC MR 2351 2379 39 6.7 149.5 7.4X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 15092 15095 5 1.0 959.5 1.0X +SQL Json 12166 12169 5 1.3 773.5 1.2X +SQL Parquet Vectorized 161 198 27 97.4 10.3 93.5X +SQL Parquet MR 2407 2412 6 6.5 153.0 6.3X +SQL ORC Vectorized 476 509 30 33.1 30.2 31.7X +SQL ORC MR 1978 1981 5 8.0 125.7 7.6X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 268 282 14 58.7 17.0 1.0X -ParquetReader Vectorized -> Row 298 321 18 52.8 18.9 0.9X +ParquetReader Vectorized 256 261 9 61.4 16.3 1.0X +ParquetReader Vectorized -> Row 210 257 22 74.7 13.4 1.2X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 21666 21697 43 0.7 1377.5 1.0X -SQL Json 18307 18363 79 0.9 1163.9 1.2X -SQL Parquet Vectorized 310 337 22 50.7 19.7 69.9X -SQL Parquet MR 3089 3103 19 5.1 196.4 7.0X -SQL ORC Vectorized 589 617 31 26.7 37.5 36.8X -SQL ORC MR 2307 2377 98 6.8 146.7 9.4X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 19785 19786 1 0.8 1257.9 1.0X +SQL Json 16339 16340 1 1.0 1038.8 1.2X +SQL Parquet Vectorized 284 302 19 55.4 18.1 69.7X +SQL Parquet MR 2570 2576 8 6.1 163.4 7.7X +SQL ORC Vectorized 473 519 32 33.3 30.0 41.9X +SQL ORC MR 2136 2142 9 7.4 135.8 9.3X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 400 415 18 39.3 25.4 1.0X -ParquetReader Vectorized -> Row 393 406 11 40.1 25.0 1.0X +ParquetReader Vectorized 298 351 32 52.8 18.9 1.0X +ParquetReader Vectorized -> Row 370 375 9 42.5 23.5 0.8X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17703 17719 22 0.6 1688.3 1.0X -SQL Json 13095 13168 103 0.8 1248.9 1.4X -SQL Parquet Vectorized 2253 2266 19 4.7 214.8 7.9X -SQL Parquet MR 4913 4977 91 2.1 468.5 3.6X -SQL ORC Vectorized 
2457 2467 14 4.3 234.3 7.2X -SQL ORC MR 4433 4464 44 2.4 422.8 4.0X +SQL CSV 13811 13824 18 0.8 1317.1 1.0X +SQL Json 11546 11589 61 0.9 1101.1 1.2X +SQL Parquet Vectorized 2143 2164 30 4.9 204.4 6.4X +SQL Parquet MR 4369 4386 24 2.4 416.7 3.2X +SQL ORC Vectorized 2289 2294 8 4.6 218.3 6.0X +SQL ORC MR 3770 3847 109 2.8 359.5 3.7X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9741 9804 89 1.1 929.0 1.0X -SQL Json 8230 8401 241 1.3 784.9 1.2X -SQL Parquet Vectorized 618 650 31 17.0 58.9 15.8X -SQL Parquet MR 2258 2311 75 4.6 215.4 4.3X -SQL ORC Vectorized 608 629 15 17.3 58.0 16.0X -SQL ORC MR 2466 2479 18 4.3 235.2 4.0X +SQL CSV 7344 7377 47 1.4 700.3 1.0X +SQL Json 7117 7153 51 1.5 678.7 1.0X +SQL Parquet Vectorized 598 618 18 17.5 57.0 12.3X +SQL Parquet MR 1955 1969 20 5.4 186.5 3.8X +SQL ORC Vectorized 559 565 8 18.8 53.3 13.1X +SQL ORC MR 1923 1932 13 5.5 183.4 3.8X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - CSV 24195 24573 534 0.7 1538.3 1.0X -Data column - Json 14746 14883 194 1.1 937.5 1.6X -Data column - Parquet Vectorized 352 385 34 44.7 22.4 68.7X -Data column - Parquet MR 3674 3694 27 4.3 233.6 6.6X -Data column - ORC Vectorized 480 505 26 32.8 30.5 50.4X -Data column - ORC MR 2913 3004 128 5.4 185.2 8.3X -Partition column - CSV 7527 7544 23 2.1 478.6 3.2X -Partition column - Json 11955 12051 135 1.3 760.1 2.0X -Partition column - Parquet Vectorized 65 92 29 242.5 4.1 373.0X -Partition column - Parquet MR 1614 1628 21 9.7 102.6 15.0X -Partition column - ORC Vectorized 71 99 29 220.1 4.5 338.5X -Partition column - ORC MR 1761 1769 11 8.9 112.0 13.7X -Both columns - CSV 24077 24127 70 0.7 1530.8 1.0X -Both columns - Json 15286 15479 273 1.0 971.9 1.6X -Both columns - Parquet Vectorized 376 412 40 41.9 23.9 64.4X -Both columns - Parquet MR 3808 3826 26 4.1 242.1 6.4X -Both columns - ORC Vectorized 560 604 42 28.1 35.6 43.2X -Both columns - ORC MR 3046 3080 49 5.2 193.7 7.9X +Data column - CSV 19266 19281 21 0.8 1224.9 1.0X +Data column - Json 13119 13126 10 1.2 834.1 1.5X +Data column - Parquet Vectorized 305 334 27 51.6 19.4 63.2X +Data column - Parquet MR 2978 3022 63 5.3 189.3 6.5X +Data column - ORC Vectorized 446 480 32 35.3 28.3 43.2X +Data column - ORC MR 2451 2469 24 6.4 155.9 7.9X +Partition column - CSV 6640 6641 1 2.4 422.2 2.9X +Partition column - Json 10485 10512 
37 1.5 666.6 1.8X +Partition column - Parquet Vectorized 65 88 24 241.2 4.1 295.4X +Partition column - Parquet MR 1403 1434 44 11.2 89.2 13.7X +Partition column - ORC Vectorized 62 86 21 253.8 3.9 310.9X +Partition column - ORC MR 1523 1525 3 10.3 96.8 12.6X +Both columns - CSV 19347 19354 10 0.8 1230.0 1.0X +Both columns - Json 13788 13793 6 1.1 876.6 1.4X +Both columns - Parquet Vectorized 346 414 70 45.5 22.0 55.7X +Both columns - Parquet MR 3022 3032 14 5.2 192.1 6.4X +Both columns - ORC Vectorized 479 519 28 32.9 30.4 40.2X +Both columns - ORC MR 2539 2540 1 6.2 161.4 7.6X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 11805 12021 306 0.9 1125.8 1.0X -SQL Json 12051 12105 77 0.9 1149.3 1.0X -SQL Parquet Vectorized 1474 1545 100 7.1 140.6 8.0X -SQL Parquet MR 4488 4492 4 2.3 428.1 2.6X -ParquetReader Vectorized 1140 1140 1 9.2 108.7 10.4X -SQL ORC Vectorized 1164 1178 20 9.0 111.0 10.1X -SQL ORC MR 3745 3817 102 2.8 357.1 3.2X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 9158 9163 8 1.1 873.3 1.0X +SQL Json 10429 10448 27 1.0 994.6 0.9X +SQL Parquet Vectorized 1363 1660 420 7.7 130.0 6.7X +SQL Parquet MR 3894 3898 5 2.7 371.4 2.4X +ParquetReader Vectorized 1021 1031 14 10.3 97.4 9.0X +SQL ORC Vectorized 1168 1191 33 9.0 111.4 7.8X +SQL ORC MR 3267 3287 28 3.2 311.6 2.8X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9814 9837 33 1.1 936.0 1.0X -SQL Json 9317 9445 182 1.1 888.5 1.1X -SQL Parquet Vectorized 1117 1155 52 9.4 106.6 8.8X -SQL Parquet MR 3463 3538 106 3.0 330.3 2.8X -ParquetReader Vectorized 1033 1039 8 10.1 98.6 9.5X -SQL ORC Vectorized 1307 1353 65 8.0 124.7 7.5X -SQL ORC MR 3644 3690 65 2.9 347.5 2.7X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 7570 7577 11 1.4 721.9 1.0X +SQL Json 8085 8096 14 1.3 771.1 0.9X +SQL Parquet Vectorized 1097 1101 5 9.6 104.7 6.9X +SQL Parquet MR 2999 3014 21 3.5 286.0 2.5X +ParquetReader Vectorized 1052 1064 18 10.0 100.3 7.2X +SQL ORC Vectorized 1286 2162 1239 8.2 122.6 5.9X +SQL ORC MR 3053 3123 100 3.4 291.1 2.5X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8145 8270 176 1.3 776.8 1.0X -SQL Json 5714 5764 71 1.8 544.9 1.4X -SQL Parquet Vectorized 235 264 15 44.6 22.4 34.7X -SQL Parquet MR 
2398 2412 19 4.4 228.7 3.4X -ParquetReader Vectorized 248 262 11 42.3 23.6 32.9X -SQL ORC Vectorized 430 462 37 24.4 41.0 18.9X -SQL ORC MR 1983 1993 14 5.3 189.1 4.1X +SQL CSV 6211 6214 3 1.7 592.4 1.0X +SQL Json 4977 4994 24 2.1 474.6 1.2X +SQL Parquet Vectorized 260 272 10 40.3 24.8 23.9X +SQL Parquet MR 1981 1985 5 5.3 188.9 3.1X +ParquetReader Vectorized 268 276 11 39.1 25.6 23.2X +SQL ORC Vectorized 428 457 35 24.5 40.8 14.5X +SQL ORC MR 1696 1705 12 6.2 161.8 3.7X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 2448 2461 18 0.4 2334.3 1.0X -SQL Json 3332 3370 53 0.3 3177.6 0.7X -SQL Parquet Vectorized 51 87 25 20.7 48.2 48.4X -SQL Parquet MR 239 278 35 4.4 227.5 10.3X -SQL ORC Vectorized 60 82 19 17.5 57.3 40.8X -SQL ORC MR 197 219 26 5.3 188.3 12.4X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 2067 2093 36 0.5 1971.6 1.0X +SQL Json 3047 5663 NaN 0.3 2906.0 0.7X +SQL Parquet Vectorized 50 73 21 20.9 47.7 41.3X +SQL Parquet MR 205 224 28 5.1 195.3 10.1X +SQL ORC Vectorized 60 79 23 17.4 57.5 34.3X +SQL ORC MR 173 196 25 6.1 165.1 11.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 6034 6061 39 0.2 5754.0 1.0X -SQL Json 12232 12315 118 0.1 11665.4 0.5X -SQL Parquet Vectorized 73 120 30 14.4 69.6 82.6X -SQL Parquet MR 316 368 44 3.3 301.1 19.1X -SQL ORC Vectorized 76 122 36 13.7 72.9 79.0X -SQL ORC MR 206 261 47 5.1 196.5 29.3X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 4841 4844 5 0.2 4616.4 1.0X +SQL Json 11721 11745 34 0.1 11177.9 0.4X +SQL Parquet Vectorized 67 101 27 15.7 63.8 72.4X +SQL Parquet MR 225 247 27 4.7 214.2 21.5X +SQL ORC Vectorized 75 99 26 13.9 71.7 64.4X +SQL ORC MR 192 219 26 5.5 183.4 25.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10307 10309 4 0.1 9829.0 1.0X -SQL Json 23412 23539 180 0.0 22327.7 0.4X -SQL Parquet Vectorized 105 151 23 10.0 99.9 98.4X -SQL Parquet MR 295 325 29 3.6 281.5 34.9X -SQL ORC Vectorized 85 112 31 12.4 81.0 121.4X -SQL ORC MR 212 255 66 4.9 202.3 48.6X +SQL CSV 8410 8414 5 0.1 8020.8 1.0X +SQL Json 22537 22923 547 0.0 21492.8 0.4X +SQL Parquet Vectorized 101 141 32 10.4 96.2 83.4X +SQL Parquet MR 262 289 45 4.0 249.9 32.1X +SQL ORC Vectorized 90 113 32 
11.7 85.4 93.9X +SQL ORC MR 210 232 36 5.0 200.3 40.0X diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt index fe083703ae0e..65db1afc5119 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -2,251 +2,269 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15943 15956 18 1.0 1013.6 1.0X -SQL Json 9109 9158 70 1.7 579.1 1.8X -SQL Parquet Vectorized 168 191 16 93.8 10.7 95.1X -SQL Parquet MR 1938 1950 17 8.1 123.2 8.2X -SQL ORC Vectorized 191 199 6 82.2 12.2 83.3X -SQL ORC MR 1523 1537 20 10.3 96.8 10.5X +SQL CSV 11497 11744 349 1.4 731.0 1.0X +SQL Json 7073 7099 37 2.2 449.7 1.6X +SQL Parquet Vectorized 105 126 17 149.9 6.7 109.6X +SQL Parquet MR 1647 1648 2 9.6 104.7 7.0X +SQL ORC Vectorized 157 167 5 100.0 10.0 73.1X +SQL ORC MR 1466 1485 27 10.7 93.2 7.8X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized 114 123 8 137.8 7.3 1.0X +ParquetReader Vectorized -> Row 42 44 1 372.1 2.7 2.7X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SQL CSV 15825 15961 193 1.0 1006.1 1.0X +SQL Json 7966 8054 125 2.0 506.5 2.0X +SQL Parquet Vectorized 136 148 9 115.4 8.7 116.1X +SQL Parquet MR 1814 1825 15 8.7 115.4 8.7X +SQL ORC Vectorized 138 147 6 114.4 8.7 115.1X +SQL ORC MR 1299 1382 117 12.1 82.6 12.2X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 203 206 3 77.5 12.9 1.0X -ParquetReader Vectorized -> Row 97 100 2 161.6 6.2 2.1X +ParquetReader Vectorized 179 185 9 88.0 11.4 1.0X +ParquetReader Vectorized -> Row 91 101 3 172.6 5.8 2.0X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT 
Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17062 17089 38 0.9 1084.8 1.0X -SQL Json 9718 9720 3 1.6 617.9 1.8X -SQL Parquet Vectorized 326 333 7 48.2 20.7 52.3X -SQL Parquet MR 2305 2329 34 6.8 146.6 7.4X -SQL ORC Vectorized 201 205 3 78.2 12.8 84.8X -SQL ORC MR 1795 1796 0 8.8 114.1 9.5X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 15449 16211 1077 1.0 982.2 1.0X +SQL Json 7955 8292 476 2.0 505.8 1.9X +SQL Parquet Vectorized 195 211 8 80.7 12.4 79.2X +SQL Parquet MR 1866 1890 33 8.4 118.7 8.3X +SQL ORC Vectorized 163 173 8 96.6 10.4 94.9X +SQL ORC MR 1550 1555 8 10.1 98.5 10.0X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 333 339 7 47.2 21.2 1.0X -ParquetReader Vectorized -> Row 283 285 3 55.7 18.0 1.2X +ParquetReader Vectorized 299 302 4 52.5 19.0 1.0X +ParquetReader Vectorized -> Row 264 280 14 59.6 16.8 1.1X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 18722 18809 123 0.8 1190.3 1.0X -SQL Json 10192 10249 80 1.5 648.0 1.8X -SQL Parquet Vectorized 155 162 8 101.6 9.8 120.9X -SQL Parquet MR 2348 2360 16 6.7 149.3 8.0X -SQL ORC Vectorized 265 275 7 59.3 16.9 70.5X -SQL ORC MR 1892 1938 65 8.3 120.3 9.9X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 16640 16834 273 0.9 1058.0 1.0X +SQL Json 8859 8862 3 1.8 563.3 1.9X +SQL Parquet Vectorized 144 155 8 109.0 9.2 115.3X +SQL Parquet MR 1960 2023 89 8.0 124.6 8.5X +SQL ORC Vectorized 218 233 11 72.3 13.8 76.5X +SQL ORC MR 1440 1442 3 10.9 91.6 11.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 243 251 7 64.8 15.4 1.0X -ParquetReader Vectorized -> Row 222 229 5 70.9 14.1 1.1X +ParquetReader Vectorized 224 241 13 70.2 14.2 1.0X +ParquetReader Vectorized -> Row 214 221 10 73.6 13.6 1.0X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 24299 24358 84 0.6 1544.9 1.0X -SQL Json 13349 
13429 114 1.2 848.7 1.8X -SQL Parquet Vectorized 215 241 59 73.3 13.6 113.2X -SQL Parquet MR 2508 2508 0 6.3 159.4 9.7X -SQL ORC Vectorized 323 330 6 48.7 20.5 75.2X -SQL ORC MR 1993 2009 22 7.9 126.7 12.2X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 22998 23324 461 0.7 1462.2 1.0X +SQL Json 12165 12179 20 1.3 773.4 1.9X +SQL Parquet Vectorized 237 265 69 66.3 15.1 96.9X +SQL Parquet MR 2199 2199 0 7.2 139.8 10.5X +SQL ORC Vectorized 303 311 10 51.9 19.3 76.0X +SQL ORC MR 1750 1763 18 9.0 111.3 13.1X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 310 351 74 50.8 19.7 1.0X -ParquetReader Vectorized -> Row 281 297 8 55.9 17.9 1.1X +ParquetReader Vectorized 331 368 80 47.6 21.0 1.0X +ParquetReader Vectorized -> Row 314 318 6 50.0 20.0 1.1X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 19745 19811 93 0.8 1255.4 1.0X -SQL Json 12523 12760 335 1.3 796.2 1.6X -SQL Parquet Vectorized 153 160 6 102.9 9.7 129.2X -SQL Parquet MR 2325 2338 18 6.8 147.8 8.5X -SQL ORC Vectorized 389 401 8 40.5 24.7 50.8X -SQL ORC MR 2009 2009 1 7.8 127.7 9.8X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 17442 18560 1581 0.9 1108.9 1.0X +SQL Json 10833 11056 315 1.5 688.8 1.6X +SQL Parquet Vectorized 150 162 10 105.0 9.5 116.5X +SQL Parquet MR 1804 1922 167 8.7 114.7 9.7X +SQL ORC Vectorized 317 336 20 49.6 20.2 55.0X +SQL ORC MR 1550 1648 139 10.1 98.5 11.3X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 240 244 4 65.5 15.3 1.0X -ParquetReader Vectorized -> Row 223 230 6 70.5 14.2 1.1X +ParquetReader Vectorized 240 263 11 65.7 15.2 1.0X +ParquetReader Vectorized -> Row 224 235 15 70.4 14.2 1.1X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 27223 27293 99 0.6 1730.8 1.0X -SQL Json 18601 18646 63 0.8 1182.6 1.5X -SQL Parquet Vectorized 247 251 3 63.8 15.7 110.4X -SQL Parquet MR 2724 2773 69 5.8 173.2 10.0X -SQL ORC Vectorized 474 484 10 33.2 30.1 57.4X -SQL ORC MR 2342 2368 37 6.7 148.9 11.6X - -OpenJDK 64-Bit Server VM 
1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 22438 23472 1462 0.7 1426.5 1.0X +SQL Json 15839 15888 70 1.0 1007.0 1.4X +SQL Parquet Vectorized 215 229 12 73.3 13.6 104.6X +SQL Parquet MR 1928 2061 188 8.2 122.6 11.6X +SQL ORC Vectorized 393 421 17 40.0 25.0 57.0X +SQL ORC MR 1799 1814 22 8.7 114.4 12.5X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 326 335 13 48.3 20.7 1.0X -ParquetReader Vectorized -> Row 358 365 7 44.0 22.7 0.9X +ParquetReader Vectorized 310 316 9 50.7 19.7 1.0X +ParquetReader Vectorized -> Row 289 302 20 54.3 18.4 1.1X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 18706 18716 15 0.6 1783.9 1.0X -SQL Json 12665 12762 138 0.8 1207.8 1.5X -SQL Parquet Vectorized 2408 2419 15 4.4 229.6 7.8X -SQL Parquet MR 4599 4620 30 2.3 438.6 4.1X -SQL ORC Vectorized 2397 2400 3 4.4 228.6 7.8X -SQL ORC MR 4267 4288 30 2.5 406.9 4.4X +SQL CSV 15669 15869 283 0.7 1494.3 1.0X +SQL Json 10126 10559 613 1.0 965.7 1.5X +SQL Parquet Vectorized 2056 2064 11 5.1 196.0 7.6X +SQL Parquet MR 3918 3927 13 2.7 373.6 4.0X +SQL ORC Vectorized 1786 1887 143 5.9 170.3 8.8X +SQL ORC MR 3521 3555 48 3.0 335.8 4.4X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10822 10838 23 1.0 1032.0 1.0X -SQL Json 7459 7488 41 1.4 711.4 1.5X -SQL Parquet Vectorized 875 895 26 12.0 83.5 12.4X -SQL Parquet MR 1976 2002 37 5.3 188.4 5.5X -SQL ORC Vectorized 533 539 8 19.7 50.9 20.3X -SQL ORC MR 2191 2194 5 4.8 208.9 4.9X +SQL CSV 8659 8948 409 1.2 825.8 1.0X +SQL Json 6410 6536 177 1.6 611.3 1.4X +SQL Parquet Vectorized 655 709 47 16.0 62.4 13.2X +SQL Parquet MR 1528 1531 3 6.9 145.7 5.7X +SQL ORC Vectorized 388 416 24 27.0 37.0 22.3X +SQL ORC MR 1599 1700 142 6.6 152.5 5.4X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 
5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - CSV 31196 31449 359 0.5 1983.4 1.0X -Data column - Json 16118 16855 1041 1.0 1024.8 1.9X -Data column - Parquet Vectorized 243 251 9 64.8 15.4 128.4X -Data column - Parquet MR 4213 4288 106 3.7 267.8 7.4X -Data column - ORC Vectorized 335 341 4 46.9 21.3 93.1X -Data column - ORC MR 3119 3146 38 5.0 198.3 10.0X -Partition column - CSV 9616 9915 423 1.6 611.3 3.2X -Partition column - Json 14136 14164 39 1.1 898.8 2.2X -Partition column - Parquet Vectorized 64 70 6 243.9 4.1 483.8X -Partition column - Parquet MR 1954 1980 38 8.1 124.2 16.0X -Partition column - ORC Vectorized 67 74 8 233.4 4.3 462.9X -Partition column - ORC MR 2461 2479 26 6.4 156.4 12.7X -Both columns - CSV 30327 30666 479 0.5 1928.2 1.0X -Both columns - Json 18656 18789 188 0.8 1186.1 1.7X -Both columns - Parquet Vectorized 291 297 7 54.0 18.5 107.2X -Both columns - Parquet MR 4430 4443 19 3.6 281.6 7.0X -Both columns - ORC Vectorized 403 411 11 39.0 25.6 77.4X -Both columns - ORC MR 3580 3584 5 4.4 227.6 8.7X +Data column - CSV 21094 21357 372 0.7 1341.1 1.0X +Data column - Json 11163 11434 383 1.4 709.7 1.9X +Data column - Parquet Vectorized 225 238 13 69.9 14.3 93.7X +Data column - Parquet MR 2218 2342 175 7.1 141.0 9.5X +Data column - ORC Vectorized 276 300 20 56.9 17.6 76.4X +Data column - ORC MR 1851 1863 17 8.5 117.7 11.4X +Partition column - CSV 5834 6119 403 2.7 370.9 3.6X +Partition column - Json 9746 9754 11 1.6 619.6 2.2X +Partition column - Parquet Vectorized 57 61 2 273.9 3.7 367.4X +Partition column - Parquet MR 1164 1167 5 13.5 74.0 18.1X +Partition column - ORC Vectorized 60 64 3 261.3 3.8 350.4X +Partition column - ORC MR 1298 1304 8 12.1 82.5 16.2X +Both columns - CSV 22632 22636 4 0.7 1438.9 0.9X +Both columns - Json 12568 12587 26 1.3 799.1 1.7X +Both columns - Parquet Vectorized 283 288 7 55.5 18.0 74.4X +Both columns - Parquet MR 2547 2553 8 6.2 161.9 8.3X +Both columns - ORC Vectorized 343 346 4 45.8 21.8 61.5X +Both columns - ORC MR 2177 2178 2 7.2 138.4 9.7X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15606 15614 11 0.7 1488.3 1.0X -SQL Json 15406 15451 63 0.7 1469.3 1.0X -SQL Parquet Vectorized 1555 1573 25 6.7 148.3 10.0X -SQL Parquet MR 5369 5377 11 2.0 512.0 2.9X -ParquetReader Vectorized 1145 1150 7 9.2 109.2 13.6X -SQL ORC Vectorized 1023 1027 6 10.2 97.6 15.3X -SQL ORC MR 4421 4542 172 2.4 421.6 3.5X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 11364 11364 0 0.9 1083.7 1.0X +SQL Json 10555 10562 9 1.0 
1006.6 1.1X +SQL Parquet Vectorized 1299 1309 13 8.1 123.9 8.7X +SQL Parquet MR 3350 3351 1 3.1 319.5 3.4X +ParquetReader Vectorized 983 987 5 10.7 93.8 11.6X +SQL ORC Vectorized 912 913 1 11.5 87.0 12.5X +SQL ORC MR 3056 3059 5 3.4 291.4 3.7X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 11096 11159 90 0.9 1058.2 1.0X -SQL Json 10797 11304 717 1.0 1029.7 1.0X -SQL Parquet Vectorized 1218 1230 16 8.6 116.2 9.1X -SQL Parquet MR 3778 3806 40 2.8 360.3 2.9X -ParquetReader Vectorized 1108 1118 14 9.5 105.7 10.0X -SQL ORC Vectorized 1361 1371 13 7.7 129.8 8.2X -SQL ORC MR 4186 4196 14 2.5 399.2 2.7X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 8651 8654 5 1.2 825.0 1.0X +SQL Json 7791 7794 4 1.3 743.0 1.1X +SQL Parquet Vectorized 1045 1055 15 10.0 99.7 8.3X +SQL Parquet MR 2516 2519 3 4.2 240.0 3.4X +ParquetReader Vectorized 927 933 6 11.3 88.4 9.3X +SQL ORC Vectorized 1285 1286 2 8.2 122.5 6.7X +SQL ORC MR 3013 3013 0 3.5 287.4 2.9X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8803 8866 90 1.2 839.5 1.0X -SQL Json 7220 7249 42 1.5 688.5 1.2X -SQL Parquet Vectorized 258 265 7 40.6 24.6 34.1X -SQL Parquet MR 2760 2761 0 3.8 263.2 3.2X -ParquetReader Vectorized 277 283 5 37.8 26.4 31.7X -SQL ORC Vectorized 514 522 6 20.4 49.1 17.1X -SQL ORC MR 2523 2591 96 4.2 240.6 3.5X +SQL CSV 6272 6288 23 1.7 598.1 1.0X +SQL Json 4469 4469 0 2.3 426.2 1.4X +SQL Parquet Vectorized 231 235 7 45.4 22.0 27.2X +SQL Parquet MR 1673 1674 2 6.3 159.5 3.7X +ParquetReader Vectorized 243 244 3 43.1 23.2 25.8X +SQL ORC Vectorized 471 472 2 22.2 45.0 13.3X +SQL ORC MR 1606 1618 17 6.5 153.2 3.9X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 3022 3032 14 0.3 2881.9 1.0X -SQL Json 4047 4051 5 0.3 3859.5 0.7X -SQL Parquet Vectorized 50 54 6 20.8 48.1 59.9X -SQL Parquet MR 299 301 2 3.5 285.0 10.1X -SQL ORC Vectorized 59 63 11 17.9 55.9 51.6X -SQL ORC MR 255 259 5 4.1 243.4 11.8X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 2171 2173 2 0.5 2070.8 1.0X +SQL Json 2266 2278 17 0.5 2161.3 1.0X +SQL Parquet Vectorized 51 55 7 20.4 49.0 42.2X +SQL Parquet MR 190 192 2 5.5 180.9 11.4X +SQL ORC Vectorized 57 61 8 18.4 54.2 38.2X +SQL ORC MR 161 164 2 
6.5 153.8 13.5X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 7250 7252 3 0.1 6914.4 1.0X -SQL Json 15641 15718 109 0.1 14916.8 0.5X -SQL Parquet Vectorized 66 72 8 15.9 62.9 110.0X -SQL Parquet MR 320 323 3 3.3 305.0 22.7X -SQL ORC Vectorized 72 77 11 14.6 68.6 100.9X -SQL ORC MR 269 273 5 3.9 256.8 26.9X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 5200 5211 15 0.2 4959.5 1.0X +SQL Json 8312 8318 8 0.1 7927.1 0.6X +SQL Parquet Vectorized 67 73 10 15.7 63.9 77.6X +SQL Parquet MR 210 214 4 5.0 200.4 24.8X +SQL ORC Vectorized 70 77 16 15.0 66.7 74.3X +SQL ORC MR 182 184 2 5.8 173.6 28.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10962 11340 535 0.1 10454.1 1.0X -SQL Json 24951 25755 1137 0.0 23795.0 0.4X -SQL Parquet Vectorized 84 93 6 12.4 80.5 129.9X -SQL Parquet MR 280 296 14 3.7 266.8 39.2X -SQL ORC Vectorized 70 76 6 15.0 66.6 156.9X -SQL ORC MR 231 242 13 4.5 220.1 47.5X +SQL CSV 9030 9032 2 0.1 8611.8 1.0X +SQL Json 15429 15462 46 0.1 14714.5 0.6X +SQL Parquet Vectorized 91 97 8 11.5 87.2 98.8X +SQL Parquet MR 235 239 3 4.5 224.2 38.4X +SQL ORC Vectorized 80 84 9 13.1 76.4 112.8X +SQL ORC MR 192 201 7 5.5 183.4 47.0X diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 40ed0b2454c1..fabdf39533d3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -48,6 +48,9 @@ public class OrcColumnarBatchReader extends RecordReader { // The capacity of vectorized batch. private int capacity; + // If the Orc file to be read is written by Spark 3.3 or after, use UTC timestamp. + private boolean useUTCTimestamp; + // Vectorized ORC Row Batch wrap. private VectorizedRowBatchWrap wrap; @@ -74,8 +77,9 @@ public class OrcColumnarBatchReader extends RecordReader { // The wrapped ORC column vectors. 
private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; - public OrcColumnarBatchReader(int capacity) { + public OrcColumnarBatchReader(int capacity, boolean useUTCTimestamp) { this.capacity = capacity; + this.useUTCTimestamp = useUTCTimestamp; } @@ -124,7 +128,8 @@ public void initialize( fileSplit.getPath(), OrcFile.readerOptions(conf) .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) - .filesystem(fileSplit.getPath().getFileSystem(conf))); + .filesystem(fileSplit.getPath().getFileSystem(conf)) + .useUTCTimestamp(useUTCTimestamp)); Reader.Options options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); recordReader = reader.rows(options); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 39591be3b4be..0eb5d65a4a8f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -53,19 +53,52 @@ public void skip() { throw new UnsupportedOperationException(); } + private void updateCurrentByte() { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + @Override public final void readBooleans(int total, WritableColumnVector c, int rowId) { - // TODO: properly vectorize this - for (int i = 0; i < total; i++) { - c.putBoolean(rowId + i, readBoolean()); + int i = 0; + if (bitOffset > 0) { + i = Math.min(8 - bitOffset, total); + c.putBooleans(rowId, i, currentByte, bitOffset); + bitOffset = (bitOffset + i) & 7; + } + for (; i + 7 < total; i += 8) { + updateCurrentByte(); + c.putBooleans(rowId + i, currentByte); + } + if (i < total) { + updateCurrentByte(); + bitOffset = total - i; + c.putBooleans(rowId + i, bitOffset, currentByte, 0); } } @Override public final void skipBooleans(int total) { - // TODO: properly vectorize this - for (int i = 0; i < total; i++) { - readBoolean(); + int i = 0; + if (bitOffset > 0) { + i = Math.min(8 - bitOffset, total); + bitOffset = (bitOffset + i) & 7; + } + if (i + 7 < total) { + int numBytesToSkip = (total - i) / 8; + try { + in.skipFully(numBytesToSkip); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to skip bytes", e); + } + i += numBytesToSkip * 8; + } + if (i < total) { + updateCurrentByte(); + bitOffset = total - i; } } @@ -276,13 +309,8 @@ public void skipShorts(int total) { @Override public final boolean readBoolean() { - // TODO: vectorize decoding and keep boolean[] instead of currentByte if (bitOffset == 0) { - try { - currentByte = (byte) in.read(); - } catch (IOException e) { - throw new ParquetDecodingException("Failed to read a byte", e); - } + updateCurrentByte(); } boolean v = (currentByte & (1 << bitOffset)) != 0; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index f7c9dc55f7ec..bbe96819a618 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -152,6 +152,18 @@ public void putBooleans(int rowId, int count, boolean value) { } } + 
@Override + public void putBooleans(int rowId, byte src) { + Platform.putByte(null, data + rowId, (byte)(src & 1)); + Platform.putByte(null, data + rowId + 1, (byte)(src >>> 1 & 1)); + Platform.putByte(null, data + rowId + 2, (byte)(src >>> 2 & 1)); + Platform.putByte(null, data + rowId + 3, (byte)(src >>> 3 & 1)); + Platform.putByte(null, data + rowId + 4, (byte)(src >>> 4 & 1)); + Platform.putByte(null, data + rowId + 5, (byte)(src >>> 5 & 1)); + Platform.putByte(null, data + rowId + 6, (byte)(src >>> 6 & 1)); + Platform.putByte(null, data + rowId + 7, (byte)(src >>> 7 & 1)); + } + @Override public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3fb96d872cd8..833a93f2a2bd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -147,6 +147,18 @@ public void putBooleans(int rowId, int count, boolean value) { } } + @Override + public void putBooleans(int rowId, byte src) { + byteData[rowId] = (byte)(src & 1); + byteData[rowId + 1] = (byte)(src >>> 1 & 1); + byteData[rowId + 2] = (byte)(src >>> 2 & 1); + byteData[rowId + 3] = (byte)(src >>> 3 & 1); + byteData[rowId + 4] = (byte)(src >>> 4 & 1); + byteData[rowId + 5] = (byte)(src >>> 5 & 1); + byteData[rowId + 6] = (byte)(src >>> 6 & 1); + byteData[rowId + 7] = (byte)(src >>> 7 & 1); + } + @Override public boolean getBoolean(int rowId) { return byteData[rowId] == 1; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8f7dcf237440..5e01c372793f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -46,6 +46,7 @@ * WritableColumnVector are intended to be reused. */ public abstract class WritableColumnVector extends ColumnVector { + private final byte[] byte8 = new byte[8]; /** * Resets this column for writing. The currently stored values are no longer accessible. @@ -201,6 +202,29 @@ public WritableColumnVector reserveDictionaryIds(int capacity) { */ public abstract void putBooleans(int rowId, int count, boolean value); + /** + * Sets bits from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) + * src must contain bit-packed 8 booleans in the byte. + */ + public void putBooleans(int rowId, int count, byte src, int srcIndex) { + assert ((srcIndex + count) <= 8); + byte8[0] = (byte)(src & 1); + byte8[1] = (byte)(src >>> 1 & 1); + byte8[2] = (byte)(src >>> 2 & 1); + byte8[3] = (byte)(src >>> 3 & 1); + byte8[4] = (byte)(src >>> 4 & 1); + byte8[5] = (byte)(src >>> 5 & 1); + byte8[6] = (byte)(src >>> 6 & 1); + byte8[7] = (byte)(src >>> 7 & 1); + putBytes(rowId, count, byte8, srcIndex); + } + + /** + * Sets bits from [src[0], src[7]] to [rowId, rowId + 7] + * src must contain bit-packed 8 booleans in the byte. + */ + public abstract void putBooleans(int rowId, byte src); + /** * Sets `value` to the value at rowId. 
*/ @@ -470,6 +494,18 @@ public final int appendBooleans(int count, boolean v) { return result; } + /** + * Append bits from [src[offset], src[offset + count]) + * src must contain bit-packed 8 booleans in the byte. + */ + public final int appendBooleans(int count, byte src, int offset) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, src, offset); + elementsAppended += count; + return result; + } + public final int appendByte(byte v) { reserve(elementsAppended + 1); putByte(elementsAppended, v); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 63812b873ba8..df110aa269e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -97,13 +97,17 @@ class SparkSession private( * since that would cause every new session to reinvoke Spark Session Extensions on the currently * running extensions. */ - private[sql] def this(sc: SparkContext) = { + private[sql] def this( + sc: SparkContext, + initialSessionOptions: java.util.HashMap[String, String]) = { this(sc, None, None, SparkSession.applyExtensions( sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty), - new SparkSessionExtensions), Map.empty) + new SparkSessionExtensions), initialSessionOptions.asScala.toMap) } + private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]()) + private[sql] val sessionUUID: String = UUID.randomUUID.toString sparkContext.assertNotStopped() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala index 7404a30fed71..3f9eb5c8084e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, CreateV2Table, LogicalPlan, ReplaceColumns, ReplaceTable} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, CreateTable, LogicalPlan, ReplaceColumns, ReplaceTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableChangeColumnCommand, CreateDataSourceTableCommand, CreateTableCommand} @@ -31,7 +31,7 @@ object ReplaceCharWithVarchar extends Rule[LogicalPlan] { plan.resolveOperators { // V2 commands - case cmd: CreateV2Table => + case cmd: CreateTable => cmd.copy(tableSchema = replaceCharWithVarcharInSchema(cmd.tableSchema)) case cmd: ReplaceTable => cmd.copy(tableSchema = replaceCharWithVarcharInSchema(cmd.tableSchema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 5362b6bf6974..0940982f7a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -22,6 +22,7 @@ import 
org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{CreateTable => CatalystCreateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL} import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, LookupCatalog, SupportsNamespaces, V1Table} @@ -143,25 +144,24 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) // For CREATE TABLE [AS SELECT], we should use the v1 command if the catalog is resolved to the // session catalog and the table provider is not v2. - case c @ CreateTableStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => + case c @ CatalystCreateTable(ResolvedDBObjectName(catalog, name), _, _, _, _) => val (storageFormat, provider) = getStorageFormatAndProvider( - c.provider, c.options, c.location, c.serde, ctas = false) - if (!isV2Provider(provider)) { - val tableDesc = buildCatalogTable(tbl.asTableIdentifier, c.tableSchema, - c.partitioning, c.bucketSpec, c.properties, provider, c.location, - c.comment, storageFormat, c.external) - val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists + c.tableSpec.provider, + c.tableSpec.options, + c.tableSpec.location, + c.tableSpec.serde, + ctas = false) + if (isSessionCatalog(catalog) && !isV2Provider(provider)) { + val tableDesc = buildCatalogTable(name.asTableIdentifier, c.tableSchema, + c.partitioning, c.tableSpec.bucketSpec, c.tableSpec.properties, provider, + c.tableSpec.location, c.tableSpec.comment, storageFormat, + c.tableSpec.external) + val mode = if (c.ignoreIfExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, None) } else { - CreateV2Table( - catalog.asTableCatalog, - tbl.asIdentifier, - c.tableSchema, - // convert the bucket spec and add it as a transform - c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c), - ignoreIfExists = c.ifNotExists) + val newTableSpec = c.tableSpec.copy(bucketSpec = None) + c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), + tableSpec = newTableSpec) } case c @ CreateTableAsSelectStatement( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index c62670b227bc..748f75b18626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -87,6 +87,11 @@ object SQLExecution { val planDescriptionMode = ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) + val globalConfigs = sparkSession.sharedState.conf.getAll.toMap + val modifiedConfigs = sparkSession.sessionState.conf.getAllConfs + .filterNot(kv => globalConfigs.get(kv._1).contains(kv._2)) + val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs) + withSQLConfPropagated(sparkSession) { var ex: Option[Throwable] = None val startTime = System.nanoTime() @@ -99,7 +104,8 @@ object SQLExecution { // `queryExecution.executedPlan` triggers query planning. 
If it fails, the exception // will be caught and reported in the `SparkListenerSQLExecutionEnd` sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), - time = System.currentTimeMillis())) + time = System.currentTimeMillis(), + redactedConfigs)) body } catch { case e: Throwable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 80ab07b15988..a7e505ebd93d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -128,7 +128,7 @@ case class DataSource( .getOrElse(true) } - bucketSpec.map { bucket => + bucketSpec.foreach { bucket => SchemaUtils.checkColumnNameDuplication( bucket.bucketColumnNames, "in the bucket definition", equality) SchemaUtils.checkColumnNameDuplication( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index ce851c58cc4f..e0b1c5ac85bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -29,7 +29,6 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.orc.{OrcUtils => _, _} import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct -import org.apache.orc.mapreduce._ import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession @@ -142,10 +141,11 @@ class OrcFileFormat val fs = filePath.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val resultedColPruneInfo = + val (resultedColPruneInfo, isOldOrcFile) = Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => - OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, reader, conf) + (OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, reader, conf), + OrcUtils.isOldOrcFile(reader.getSchema)) } if (resultedColPruneInfo.isEmpty) { @@ -155,7 +155,7 @@ class OrcFileFormat if (orcFilterPushDown && filters.nonEmpty) { OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) + mapreduce.OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) } } } @@ -174,7 +174,7 @@ class OrcFileFormat val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) if (enableVectorizedReader) { - val batchReader = new OrcColumnarBatchReader(capacity) + val batchReader = new OrcColumnarBatchReader(capacity, !isOldOrcFile) // SPARK-23399 Register a task completion listener first to call `close()` in all cases. // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) // after opening a file. 
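The SQLExecution hunk above computes, per query, only the SQL configs whose session value differs from the shared (global) value, redacts sensitive entries, and attaches the result to the SparkListenerSQLExecutionStart event. A minimal self-contained sketch of that diff-and-redact pattern follows; the redaction regex, the replacement text and the example keys are illustrative assumptions, not Spark's actual defaults.

object ModifiedConfigsSketch {
  // Keep only the keys whose session value differs from (or is missing in) the global conf.
  def modified(global: Map[String, String], session: Map[String, String]): Map[String, String] =
    session.filterNot { case (k, v) => global.get(k).contains(v) }

  // Mask the value of any key matching the redaction pattern (assumed pattern, for illustration).
  def redact(confs: Map[String, String],
             pattern: String = "(?i)secret|password"): Map[String, String] =
    confs.map { case (k, v) =>
      if (pattern.r.findFirstIn(k).isDefined) k -> "*********(redacted)" else k -> v
    }

  def main(args: Array[String]): Unit = {
    val global = Map("spark.sql.shuffle.partitions" -> "200")
    val session = Map(
      "spark.sql.shuffle.partitions" -> "10",  // overridden in this session, so it is reported
      "redaction.password" -> "123")           // reported, but with its value masked
    println(redact(modified(global, session)))
    // Map(spark.sql.shuffle.partitions -> 10, redaction.password -> *********(redacted))
  }
}

Diffing against the shared conf keeps the per-query payload small: only settings the user actually overrode in that session (for example via SET) end up in the event and, later, in the UI.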
@@ -193,8 +193,8 @@ class OrcFileFormat iter.asInstanceOf[Iterator[InternalRow]] } else { - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) + val orcRecordReader = + OrcUtils.createRecordReader[OrcStruct](fileSplit, taskAttemptContext, !isOldOrcFile) val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index fe057e0ddfc4..1d629970d825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -45,6 +45,7 @@ private[sql] class OrcOutputWriter( val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc") val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration) options.setSchema(OrcUtils.orcTypeDescription(dataSchema)) + options.useUTCTimestamp(true) val writer = OrcFile.createWriter(filename, options) val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer) OrcUtils.addSparkVersionMetadata(writer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index d1b7e8db619b..8be02b5eeca6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -28,7 +28,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable} -import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} +import org.apache.hadoop.mapreduce.{InputSplit, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.orc.{mapreduce, BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} import org.apache.orc.mapred.OrcTimestamp import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} @@ -142,6 +144,50 @@ object OrcUtils extends Logging { CharVarcharUtils.replaceCharVarcharWithStringInSchema(toStructType(schema)) } + /** + * Return true if the ORC file being read was written by Spark 3.1 or prior.
+ */ + def isOldOrcFile(schema: TypeDescription): Boolean = { + import TypeDescription.Category + + def find(orcType: TypeDescription): Boolean = { + orcType.getCategory match { + case Category.STRUCT => findInStruct(orcType) + case Category.LIST => findInArray(orcType) + case Category.MAP => findInMap(orcType) + case Category.TIMESTAMP => + if (orcType.getAttributeValue(CATALYST_TYPE_ATTRIBUTE_NAME) == null) { + true + } else { + false + } + case _ => false + } + } + + def findInStruct(orcType: TypeDescription): Boolean = { + val fieldTypes = orcType.getChildren.asScala + for (fieldType <- fieldTypes) { + if (find(fieldType)) { + return true + } + } + false + } + + def findInArray(orcType: TypeDescription): Boolean = { + val elementType = orcType.getChildren.get(0) + find(elementType) + } + + def findInMap(orcType: TypeDescription): Boolean = { + val Seq(keyType, valueType) = orcType.getChildren.asScala.toSeq + find(keyType) || find(valueType) + } + + find(schema) + } + def readSchema(sparkSession: SparkSession, files: Seq[FileStatus], options: Map[String, String]) : Option[StructType] = { val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles @@ -542,4 +588,24 @@ object OrcUtils extends Logging { result.setNanos(nanos.toInt) result } + + /** + * This method references createRecordReader of OrcInputFormat. + * Just for call useUTCTimestamp of OrcFile.ReaderOptions. + * + * @return OrcMapreduceRecordReader + */ + def createRecordReader[V <: WritableComparable[_]]( + inputSplit: InputSplit, + taskAttemptContext: TaskAttemptContext, + useUTCTimestamp: Boolean): mapreduce.OrcMapreduceRecordReader[V] = { + val split = inputSplit.asInstanceOf[FileSplit] + val conf = taskAttemptContext.getConfiguration() + val readOptions = OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)).useUTCTimestamp(useUTCTimestamp) + val file = OrcFile.createReader(split.getPath(), readOptions) + val options = org.apache.orc.mapred.OrcInputFormat.buildOptions( + conf, file, split.getStart(), split.getLength()).useSelected(true) + new mapreduce.OrcMapreduceRecordReader(file, options) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0e8efb629706..327d92672db8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} @@ -81,7 +82,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // bucketing information is specified, as we can't infer bucketing from data files currently. // Since the runtime inferred partition columns could be different from what user specified, // we fail the query if the partitioning information is specified. 
- case c @ CreateTable(tableDesc, _, None) if tableDesc.schema.isEmpty => + case c @ CreateTableV1(tableDesc, _, None) if tableDesc.schema.isEmpty => if (tableDesc.bucketSpec.isDefined) { failAnalysis("Cannot specify bucketing information if the table schema is not specified " + "when creating and will be inferred at runtime") @@ -96,7 +97,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // When we append data to an existing table, check if the given provider, partition columns, // bucket spec, etc. match the existing table, and adjust the columns order of the given query // if necessary. - case c @ CreateTable(tableDesc, SaveMode.Append, Some(query)) + case c @ CreateTableV1(tableDesc, SaveMode.Append, Some(query)) if query.resolved && catalog.tableExists(tableDesc.identifier) => // This is guaranteed by the parser and `DataFrameWriter` assert(tableDesc.provider.isDefined) @@ -189,7 +190,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // * partition columns' type must be AtomicType. // * sort columns' type must be orderable. // * reorder table schema or output of query plan, to put partition columns at the end. - case c @ CreateTable(tableDesc, _, query) if query.forall(_.resolved) => + case c @ CreateTableV1(tableDesc, _, query) if query.forall(_.resolved) => if (query.isDefined) { assert(tableDesc.schema.isEmpty, "Schema may not be specified in a Create Table As Select (CTAS) statement") @@ -433,7 +434,7 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { object HiveOnlyCheck extends (LogicalPlan => Unit) { def apply(plan: LogicalPlan): Unit = { plan.foreach { - case CreateTable(tableDesc, _, _) if DDLUtils.isHiveTable(tableDesc) => + case CreateTableV1(tableDesc, _, _) if DDLUtils.isHiveTable(tableDesc) => throw QueryCompilationErrors.ddlWithoutHiveSupportEnabledError( "CREATE Hive TABLE (AS SELECT)") case i: InsertIntoDir if DDLUtils.isHiveTable(i.provider) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala index be7331b0d7dc..6e5c3af4573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala @@ -22,7 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} +import org.apache.spark.sql.catalyst.plans.logical.TableSpec +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType @@ -32,10 +33,18 @@ case class CreateTableExec( identifier: Identifier, tableSchema: StructType, partitioning: Seq[Transform], - tableProperties: Map[String, String], + tableSpec: TableSpec, ignoreIfExists: Boolean) extends LeafV2CommandExec { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + val tableProperties = { + val props = CatalogV2Util.convertTableProperties( + tableSpec.properties, tableSpec.options, tableSpec.serde, + tableSpec.location, tableSpec.comment, 
tableSpec.provider, + tableSpec.external) + CatalogV2Util.withDefaultOwnership(props) + } + override protected def run(): Seq[InternalRow] = { if (!catalog.tableExists(identifier)) { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 026ff63608bb..f64c1ee001be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -165,9 +165,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query), customMetrics) :: Nil - case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) => - val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) - CreateTableExec(catalog, ident, schema, parts, propsWithOwner, ifNotExists) :: Nil + case CreateTable(ResolvedDBObjectName(catalog, ident), schema, partitioning, + tableSpec, ifNotExists) => + CreateTableExec(catalog.asTableCatalog, ident.asIdentifier, schema, + partitioning, tableSpec, ifNotExists) :: Nil case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) => val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala index f7d79a1259ea..3be9b5c5471a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.CharVarcharUtils case class DescribeColumnExec( override val output: Seq[Attribute], @@ -37,7 +38,8 @@ case class DescribeColumnExec( } rows += toCatalystRow("col_name", column.name) - rows += toCatalystRow("data_type", column.dataType.catalogString) + rows += toCatalystRow("data_type", + CharVarcharUtils.getRawType(column.metadata).getOrElse(column.dataType).catalogString) rows += toCatalystRow("comment", comment) // TODO: The extended description (isExtended = true) can be added here. 
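The DescribeColumnExec change above stops reporting CHAR/VARCHAR columns as plain string: it prefers the original type that Spark records in the column metadata (via CharVarcharUtils) over the physical catalogString. A rough standalone sketch of that lookup follows; the metadata key below is an assumption used for illustration, and the real code calls CharVarcharUtils.getRawType rather than reading a key directly.

import org.apache.spark.sql.types._

object DescribeCharVarcharSketch {
  // Assumed key under which the original char/varchar type string is kept in column metadata.
  val rawTypeKey = "__CHAR_VARCHAR_TYPE_STRING"

  // Prefer the recorded raw type (e.g. "varchar(3)") over the physical type ("string").
  def displayType(field: StructField): String =
    if (field.metadata.contains(rawTypeKey)) field.metadata.getString(rawTypeKey)
    else field.dataType.catalogString

  def main(args: Array[String]): Unit = {
    val md = new MetadataBuilder().putString(rawTypeKey, "varchar(3)").build()
    println(displayType(StructField("v", StringType, nullable = true, md)))  // varchar(3)
    println(displayType(StructField("c", StringType)))                       // string
  }
}

This is what lets DESC t v in the char/varchar test suites report varchar(3) even though the column is stored internally as a string.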
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 8b0328cabc5a..21503fda53e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -136,10 +136,10 @@ trait FileScan extends Scan val partitionAttributes = fileIndex.partitionSchema.toAttributes val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap val readPartitionAttributes = readPartitionSchema.map { readField => - attributeMap.get(normalizeName(readField.name)).getOrElse { + attributeMap.getOrElse(normalizeName(readField.name), throw QueryCompilationErrors.cannotFindPartitionColumnInPartitionSchemaError( readField, fileIndex.partitionSchema) - } + ) } lazy val partitionValueProject = GenerateUnsafeProjection.generate(readPartitionAttributes, partitionAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index ec6a3bbc2618..816ced133e49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -88,10 +88,11 @@ case class OrcPartitionReaderFactory( } val filePath = new Path(new URI(file.filePath)) - val resultedColPruneInfo = + val (resultedColPruneInfo, isOldOrcFile) = Utils.tryWithResource(createORCReader(filePath, conf)) { reader => - OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, readDataSchema, reader, conf) + (OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, readDataSchema, reader, conf), + OrcUtils.isOldOrcFile(reader.getSchema)) } if (resultedColPruneInfo.isEmpty) { @@ -108,8 +109,8 @@ case class OrcPartitionReaderFactory( val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) + val orcRecordReader = + OrcUtils.createRecordReader[OrcStruct](fileSplit, taskAttemptContext, !isOldOrcFile) val deserializer = new OrcDeserializer(readDataSchema, requestedColIds) val fileReader = new PartitionReader[InternalRow] { override def next(): Boolean = orcRecordReader.nextKeyValue() @@ -131,10 +132,11 @@ case class OrcPartitionReaderFactory( } val filePath = new Path(new URI(file.filePath)) - val resultedColPruneInfo = + val (resultedColPruneInfo, isOldOrcFile) = Utils.tryWithResource(createORCReader(filePath, conf)) { reader => - OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, readDataSchema, reader, conf) + (OrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, readDataSchema, reader, conf), + OrcUtils.isOldOrcFile(reader.getSchema)) } if (resultedColPruneInfo.isEmpty) { @@ -152,7 +154,7 @@ case class OrcPartitionReaderFactory( val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) - val batchReader = new OrcColumnarBatchReader(capacity) + val batchReader = new OrcColumnarBatchReader(capacity, !isOldOrcFile) batchReader.initialize(fileSplit, 
taskAttemptContext) val requestedPartitionColIds = Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index a2b33c2ba303..c88e6ae3f477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -116,7 +116,7 @@ private[sql] class RocksDBStateStoreProvider rocksDBMetrics.nativeOpsHistograms.get(typ).map(_.count).getOrElse(0) } def nativeOpsMetrics(typ: String): Long = { - rocksDBMetrics.nativeOpsMetrics.get(typ).getOrElse(0) + rocksDBMetrics.nativeOpsMetrics.getOrElse(typ, 0) } val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index b15c70a7eba7..b8575b052b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -81,7 +81,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging summary ++ planVisualization(request, metrics, graph) ++ - physicalPlanDescription(executionUIData.physicalPlanDescription) + physicalPlanDescription(executionUIData.physicalPlanDescription) ++ + modifiedConfigs(executionUIData.modifiedConfigs) }.getOrElse {

<div>No information to display for query {executionId}</div>
} @@ -145,4 +146,28 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
} + + private def modifiedConfigs(modifiedConfigs: Map[String, String]): Seq[Node] = { + val configs = UIUtils.listingTable( + propertyHeader, + propertyRow, + modifiedConfigs.toSeq.sorted, + fixedWidth = true + ) + + <div> + <h4>SQL Properties</h4> + {configs} + </div>
+ } + + private def propertyHeader = Seq("Name", "Value") + private def propertyRow(kv: (String, String)) = {kv._1}{kv._2} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index e7ab4a184b07..d892dbdc2316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -93,6 +93,7 @@ class SQLAppStatusListener( executionData.description = sqlStoreData.description executionData.details = sqlStoreData.details executionData.physicalPlanDescription = sqlStoreData.physicalPlanDescription + executionData.modifiedConfigs = sqlStoreData.modifiedConfigs executionData.metrics = sqlStoreData.metrics executionData.submissionTime = sqlStoreData.submissionTime executionData.completionTime = sqlStoreData.completionTime @@ -336,7 +337,7 @@ class SQLAppStatusListener( private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { val SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) = event + physicalPlanDescription, sparkPlanInfo, time, modifiedConfigs) = event val planGraph = SparkPlanGraph(sparkPlanInfo) val sqlPlanMetrics = planGraph.allNodes.flatMap { node => @@ -353,6 +354,7 @@ class SQLAppStatusListener( exec.description = description exec.details = details exec.physicalPlanDescription = physicalPlanDescription + exec.modifiedConfigs = modifiedConfigs exec.metrics = sqlPlanMetrics exec.submissionTime = time update(exec) @@ -479,6 +481,7 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity { var description: String = null var details: String = null var physicalPlanDescription: String = null + var modifiedConfigs: Map[String, String] = _ var metrics = Seq[SQLPlanMetric]() var submissionTime = -1L var completionTime: Option[Date] = None @@ -499,6 +502,7 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity { description, details, physicalPlanDescription, + modifiedConfigs, metrics, submissionTime, completionTime, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index a90f37a80d52..7c3315e3d76e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -86,6 +86,7 @@ class SQLExecutionUIData( val description: String, val details: String, val physicalPlanDescription: String, + val modifiedConfigs: Map[String, String], val metrics: Seq[SQLPlanMetric], val submissionTime: Long, val completionTime: Option[Date], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 6a6a71c46f21..26805e135b77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,7 +47,8 @@ case class SparkListenerSQLExecutionStart( details: String, physicalPlanDescription: String, sparkPlanInfo: SparkPlanInfo, - time: Long) + time: Long, + modifiedConfigs: Map[String, String] = Map.empty) extends SparkListenerEvent @DeveloperApi diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 10ce9d3aaf01..2d3c89874f59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -27,8 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedDBObjectName import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.CreateTableStatement +import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, TableSpec} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback} @@ -288,10 +289,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * Note, currently the new table creation by this API doesn't fully cover the V2 table. * TODO (SPARK-33638): Full support of v2 table creation */ - val cmd = CreateTableStatement( - originalMultipartIdentifier, - df.schema.asNullable, - partitioningColumns.getOrElse(Nil).asTransforms.toSeq, + val tableProperties = TableSpec( None, Map.empty[String, String], Some(source), @@ -299,8 +297,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("path"), None, None, - external = false, - ifNotExists = false) + false) + val cmd = CreateTable( + UnresolvedDBObjectName( + originalMultipartIdentifier, + isNamespace = false), + df.schema.asNullable, + partitioningColumns.getOrElse(Nil).asTransforms.toSeq, + tableProperties, + ignoreIfExists = false) Dataset.ofRows(df.sparkSession, cmd) } diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index b95c8dac9a82..c3c09778a228 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -230,7 +230,8 @@ select next_day(timestamp_ntz"2015-07-23 12:12:12", "Mon") struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'next_day(TIMESTAMP_NTZ '2015-07-23 12:12:12', 'Mon')' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2015-07-23 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +cannot resolve 'next_day(TIMESTAMP_NTZ '2015-07-23 12:12:12', 'Mon')' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2015-07-23 12:12:12'' is of timestamp_ntz type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -498,7 +499,8 @@ select date_add(date_str, 1) from date_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type.; line 1 pos 7 +cannot resolve 'date_add(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. 
+To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -507,7 +509,8 @@ select date_sub(date_str, 1) from date_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type.; line 1 pos 7 +cannot resolve 'date_sub(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -589,7 +592,8 @@ select date_str - date '2001-09-28' from date_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(date_view.date_str - DATE '2001-09-28')' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type.; line 1 pos 7 +cannot resolve '(date_view.date_str - DATE '2001-09-28')' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index e9c323254b4a..230393f02ac3 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1533,7 +1533,8 @@ select str - interval '4 22:12' day to minute from interval_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'interval_view.str + (- INTERVAL '4 22:12' DAY TO MINUTE)' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type.; line 1 pos 7 +cannot resolve 'interval_view.str + (- INTERVAL '4 22:12' DAY TO MINUTE)' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -1542,7 +1543,8 @@ select str + interval '4 22:12' day to minute from interval_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'interval_view.str + INTERVAL '4 22:12' DAY TO MINUTE' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type.; line 1 pos 7 +cannot resolve 'interval_view.str + INTERVAL '4 22:12' DAY TO MINUTE' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type. +To fix the error, you might need to add explicit type casts. 
If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out index 13f3fe064aef..84dcf3aca7ca 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out @@ -686,6 +686,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException Union can only be performed on tables with the compatible column types. The first column of the second table is string type which is not compatible with decimal(38,18) at same column of first table +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error. -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 7be54d49a90e..f2df9af9ed8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -843,17 +843,6 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } } - // TODO(SPARK-33875): Move these tests to super after DESCRIBE COLUMN v2 implemented - test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") { - withTable("t") { - sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") - checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"), - Row("varchar(3)")) - checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"), - Row("char(5)")) - } - } - // TODO(SPARK-33898): Move these tests to super after SHOW CREATE TABLE for v2 implemented test("SPARK-33892: SHOW CREATE TABLE w/ char/varchar") { withTable("t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index a090eba43006..76b3324e3e1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -27,11 +27,11 @@ import org.scalatest.Assertions._ import org.apache.spark.TestUtils import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{DataType, StringType} /** * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and @@ -218,6 +218,29 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String } + class PythonUDFWithoutId( + name: String, + func: PythonFunction, + dataType: DataType, + children: Seq[Expression], + evalType: Int, + udfDeterministic: Boolean, + resultId: ExprId) + extends PythonUDF(name, func, dataType, children, evalType, udfDeterministic, resultId) { + + def this(pudf: PythonUDF) = { + this(pudf.name, pudf.func, pudf.dataType, 
pudf.children, + pudf.evalType, pudf.udfDeterministic, pudf.resultId) + } + + override def toString: String = s"$name(${children.mkString(", ")})" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PythonUDFWithoutId = { + new PythonUDFWithoutId(super.withNewChildrenInternal(newChildren)) + } + } + /** * A Python UDF that takes one column, casts into string, executes the Python native function, * and casts back to the type of input column. @@ -253,7 +276,9 @@ object IntegratedUDFTestUtils extends SQLHelper { val expr = e.head assert(expr.resolved, "column should be resolved to use the same type " + "as input. Try df(name) or df.col(name)") - Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) + val pythonUDF = new PythonUDFWithoutId( + super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) + Cast(pythonUDF, expr.dataType) } } @@ -297,7 +322,9 @@ object IntegratedUDFTestUtils extends SQLHelper { val expr = e.head assert(expr.resolved, "column should be resolved to use the same type " + "as input. Try df(name) or df.col(name)") - Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) + val pythonUDF = new PythonUDFWithoutId( + super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) + Cast(pythonUDF, expr.dataType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index c33f47eb1f67..47691063b10b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -2911,6 +2911,36 @@ class DataSourceV2SQLSuite } } + test("Check HasPartitionKey from InMemoryPartitionTable") { + val t = "testpart.tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id string) USING foo PARTITIONED BY (key int)") + val table = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array(), "tbl")) + .asInstanceOf[InMemoryPartitionTable] + + sql(s"INSERT INTO $t VALUES ('a', 1), ('b', 2), ('c', 3)") + var partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 3) + assert(partKeys.toSet == Set(1, 2, 3)) + + sql(s"ALTER TABLE $t DROP PARTITION (key=3)") + partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 2) + assert(partKeys.toSet == Set(1, 2)) + + sql(s"ALTER TABLE $t ADD PARTITION (key=4)") + partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 3) + assert(partKeys.toSet == Set(1, 2, 4)) + + sql(s"INSERT INTO $t VALUES ('c', 3), ('e', 5)") + partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 5) + assert(partKeys.toSet == Set(1, 2, 3, 4, 5)) + } + } + test("time travel") { sql("use testcat") // The testing in-memory table simply append the version/timestamp to the table name when @@ -2956,6 +2986,8 @@ class DataSourceV2SQLSuite === Array(Row(7), Row(8))) assert(sql("SELECT * FROM t TIMESTAMP AS OF make_date(2021, 1, 29)").collect === Array(Row(7), Row(8))) + assert(sql("SELECT * FROM t TIMESTAMP AS OF to_timestamp('2021-01-29 00:00:00')").collect + === Array(Row(7), Row(8))) val e1 = intercept[AnalysisException]( sql("SELECT * FROM t TIMESTAMP AS OF INTERVAL 1 DAY").collect() @@ -2968,9 +3000,19 @@ class DataSourceV2SQLSuite assert(e2.message.contains("is not a valid timestamp expression for time travel")) val e3 = intercept[AnalysisException]( + sql("SELECT 
* FROM t TIMESTAMP AS OF current_user()").collect() + ) + assert(e3.message.contains("is not a valid timestamp expression for time travel")) + + val e4 = intercept[AnalysisException]( + sql("SELECT * FROM t TIMESTAMP AS OF CAST(rand() AS STRING)").collect() + ) + assert(e4.message.contains("is not a valid timestamp expression for time travel")) + + val e5 = intercept[AnalysisException]( sql("SELECT * FROM t TIMESTAMP AS OF abs(true)").collect() ) - assert(e3.message.contains("cannot resolve 'abs(true)' due to data type mismatch")) + assert(e5.message.contains("cannot resolve 'abs(true)' due to data type mismatch")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 81e692076b43..740c10f17b26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.sql.execution +import java.util.Locale import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicInteger import scala.collection.parallel.immutable.ParRange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart import org.apache.spark.sql.types._ import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT class SQLExecutionSuite extends SparkFunSuite { @@ -157,6 +161,45 @@ class SQLExecutionSuite extends SparkFunSuite { } } } + + test("SPARK-34735: Add modified configs for SQL execution in UI") { + val spark = SparkSession.builder() + .master("local[*]") + .appName("test") + .config("k1", "v1") + .getOrCreate() + + try { + val index = new AtomicInteger(0) + spark.sparkContext.addSparkListener(new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case start: SparkListenerSQLExecutionStart => + if (index.get() == 0 && hasProject(start)) { + assert(!start.modifiedConfigs.contains("k1")) + index.incrementAndGet() + } else if (index.get() == 1 && hasProject(start)) { + assert(start.modifiedConfigs.contains("k2")) + assert(start.modifiedConfigs("k2") == "v2") + assert(start.modifiedConfigs.contains("redaction.password")) + assert(start.modifiedConfigs("redaction.password") == REDACTION_REPLACEMENT_TEXT) + index.incrementAndGet() + } + case _ => + } + + private def hasProject(start: SparkListenerSQLExecutionStart): Boolean = + start.physicalPlanDescription.toLowerCase(Locale.ROOT).contains("project") + }) + spark.sql("SELECT 1").collect() + spark.sql("SET k2 = v2") + spark.sql("SET redaction.password = 123") + spark.sql("SELECT 1").collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(index.get() == 2) + } finally { + spark.stop() + } + } } object SQLExecutionSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala index 08789e63fa7f..55f171342249 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.SparkListenerEvent import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} import org.apache.spark.sql.test.TestSparkSession @@ -28,28 +29,46 @@ import org.apache.spark.util.JsonProtocol class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession { test("SparkPlanGraph backward compatibility: metadata") { - val SQLExecutionStartJsonString = - """ - |{ - | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", - | "executionId":0, - | "description":"test desc", - | "details":"test detail", - | "physicalPlanDescription":"test plan", - | "sparkPlanInfo": { - | "nodeName":"TestNode", - | "simpleString":"test string", - | "children":[], - | "metadata":{}, - | "metrics":[] - | }, - | "time":0 - |} + Seq(true, false).foreach { newExecutionStartEvent => + val event = if (newExecutionStartEvent) { + "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart" + } else { + "org.apache.spark.sql.execution.OldVersionSQLExecutionStart" + } + val SQLExecutionStartJsonString = + s""" + |{ + | "Event":"$event", + | "executionId":0, + | "description":"test desc", + | "details":"test detail", + | "physicalPlanDescription":"test plan", + | "sparkPlanInfo": { + | "nodeName":"TestNode", + | "simpleString":"test string", + | "children":[], + | "metadata":{}, + | "metrics":[] + | }, + | "time":0, + | "modifiedConfigs": { + | "k1":"v1" + | } + |} """.stripMargin - val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) - val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", - new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) - assert(reconstructedEvent == expectedEvent) + + val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) + if (newExecutionStartEvent) { + val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", + "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0, + Map("k1" -> "v1")) + assert(reconstructedEvent == expectedEvent) + } else { + val expectedOldEvent = OldVersionSQLExecutionStart(0, "test desc", "test detail", + "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) + assert(reconstructedEvent == expectedOldEvent) + } + } } test("SparkListenerSQLExecutionEnd backward compatibility") { @@ -77,3 +96,12 @@ class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession { assert(readBack == event) } } + +private case class OldVersionSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 0fc43c7052d0..0e9e9a706027 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -119,31 +119,36 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + val query = dataType match { + case BooleanType => "sum(cast(id as bigint))" + case _ => "sum(id)" + } + sqlBenchmark.addCase("SQL CSV") { _ => - spark.sql("select sum(id) from csvTable").noop() + spark.sql(s"select $query from csvTable").noop() } sqlBenchmark.addCase("SQL Json") { _ => - spark.sql("select sum(id) from jsonTable").noop() + spark.sql(s"select $query from jsonTable").noop() } sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql("select sum(id) from parquetTable").noop() + spark.sql(s"select $query from parquetTable").noop() } sqlBenchmark.addCase("SQL Parquet MR") { _ => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from parquetTable").noop() + spark.sql(s"select $query from parquetTable").noop() } } sqlBenchmark.addCase("SQL ORC Vectorized") { _ => - spark.sql("SELECT sum(id) FROM orcTable").noop() + spark.sql(s"SELECT $query FROM orcTable").noop() } sqlBenchmark.addCase("SQL ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("SELECT sum(id) FROM orcTable").noop() + spark.sql(s"SELECT $query FROM orcTable").noop() } } @@ -157,6 +162,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { var longSum = 0L var doubleSum = 0.0 val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case BooleanType => (col: ColumnVector, i: Int) => if (col.getBoolean(i)) longSum += 1L case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) @@ -191,6 +197,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { var longSum = 0L var doubleSum = 0.0 val aggregateValue: (InternalRow) => Unit = dataType match { + case BooleanType => (col: InternalRow) => if (col.getBoolean(0)) longSum += 1L case ByteType => (col: InternalRow) => longSum += col.getByte(0) case ShortType => (col: InternalRow) => longSum += col.getShort(0) case IntegerType => (col: InternalRow) => longSum += col.getInt(0) @@ -542,7 +549,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("SQL Single Numeric Column Scan") { - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { + Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala index 2aef62988fa2..0713e9be3f5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala @@ -150,6 +150,16 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { } } } + + test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") { + withTable("t") { + sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") + checkAnswer(sql("desc t 
v").selectExpr("info_value").where("info_value like '%char%'"), + Row("varchar(3)")) + checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"), + Row("char(5)")) + } + } } class FileSourceCharVarcharDDLTestSuite extends CharVarcharDDLTestBase with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 85ba14fc7a44..a6b979a3fd52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -26,17 +26,17 @@ import org.mockito.invocation.InvocationOnMock import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedDBObjectName, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} -import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.sources.SimpleScanSource @@ -210,7 +210,7 @@ class PlanResolutionSuite extends AnalysisTest { private def extractTableDesc(sql: 
String): (CatalogTable, Boolean) = { parseAndResolve(sql).collect { - case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + case CreateTableV1(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) }.head } @@ -240,7 +240,7 @@ class PlanResolutionSuite extends AnalysisTest { ) parseAndResolve(query) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -282,7 +282,7 @@ class PlanResolutionSuite extends AnalysisTest { ) parseAndResolve(query) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -302,7 +302,7 @@ class PlanResolutionSuite extends AnalysisTest { comment = Some("abc")) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -322,7 +322,7 @@ class PlanResolutionSuite extends AnalysisTest { properties = Map("test" -> "test")) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -341,7 +341,7 @@ class PlanResolutionSuite extends AnalysisTest { provider = Some("parquet")) parseAndResolve(v1) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -372,7 +372,7 @@ class PlanResolutionSuite extends AnalysisTest { provider = Some("parquet")) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -398,7 +398,7 @@ class PlanResolutionSuite extends AnalysisTest { ) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -471,29 +471,20 @@ class PlanResolutionSuite extends AnalysisTest { |OPTIONS (path 's3://bucket/path/to/data', other 20) """.stripMargin - val expectedProperties = Map( - "p1" -> "v1", - "p2" -> "v2", - "option.other" -> "20", - "provider" -> "parquet", - "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment", - "other" -> "20") - parseAndResolve(sql) match { - case create: CreateV2Table => - assert(create.catalog.name == "testcat") - assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + case create: CreateTable => + assert(create.name.asInstanceOf[ResolvedDBObjectName].catalog.name == 
"testcat") + assert(create.name.asInstanceOf[ResolvedDBObjectName].nameParts.mkString(".") == + "mydb.table_name") assert(create.tableSchema == new StructType() .add("id", LongType) .add("description", StringType) .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) assert(create.partitioning.isEmpty) - assert(create.properties == expectedProperties) assert(create.ignoreIfExists) case other => - fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + fail(s"Expected to parse ${classOf[CreateTable].getName} from query," + s"got ${other.getClass.getName}: $sql") } } @@ -511,29 +502,20 @@ class PlanResolutionSuite extends AnalysisTest { |OPTIONS (path 's3://bucket/path/to/data', other 20) """.stripMargin - val expectedProperties = Map( - "p1" -> "v1", - "p2" -> "v2", - "option.other" -> "20", - "provider" -> "parquet", - "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment", - "other" -> "20") - parseAndResolve(sql, withDefault = true) match { - case create: CreateV2Table => - assert(create.catalog.name == "testcat") - assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + case create: CreateTable => + assert(create.name.asInstanceOf[ResolvedDBObjectName].catalog.name == "testcat") + assert(create.name.asInstanceOf[ResolvedDBObjectName].nameParts.mkString(".") == + "mydb.table_name") assert(create.tableSchema == new StructType() .add("id", LongType) .add("description", StringType) .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) assert(create.partitioning.isEmpty) - assert(create.properties == expectedProperties) assert(create.ignoreIfExists) case other => - fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + fail(s"Expected to parse ${classOf[CreateTable].getName} from query," + s"got ${other.getClass.getName}: $sql") } } @@ -551,27 +533,21 @@ class PlanResolutionSuite extends AnalysisTest { |TBLPROPERTIES ('p1'='v1', 'p2'='v2') """.stripMargin - val expectedProperties = Map( - "p1" -> "v1", - "p2" -> "v2", - "provider" -> v2Format, - "location" -> "/user/external/page_view", - "comment" -> "This is the staging page view table") - parseAndResolve(sql) match { - case create: CreateV2Table => - assert(create.catalog.name == CatalogManager.SESSION_CATALOG_NAME) - assert(create.tableName == Identifier.of(Array("mydb"), "page_view")) + case create: CreateTable => + assert(create.name.asInstanceOf[ResolvedDBObjectName].catalog.name == + CatalogManager.SESSION_CATALOG_NAME) + assert(create.name.asInstanceOf[ResolvedDBObjectName].nameParts.mkString(".") == + "mydb.page_view") assert(create.tableSchema == new StructType() .add("id", LongType) .add("description", StringType) .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) assert(create.partitioning.isEmpty) - assert(create.properties == expectedProperties) assert(create.ignoreIfExists) case other => - fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + fail(s"Expected to parse ${classOf[CreateTable].getName} from query," + s"got ${other.getClass.getName}: $sql") } } @@ -1684,9 +1660,9 @@ class PlanResolutionSuite extends AnalysisTest { */ def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan match { - case CreateTable(tableDesc, mode, query) => + case CreateTableV1(tableDesc, mode, query) => val newTableDesc = tableDesc.copy(createTime = -1L) - CreateTable(newTableDesc, mode, query) + CreateTableV1(newTableDesc, mode, query) case _ => plan // Don't transform } } @@ 
-1707,8 +1683,8 @@ class PlanResolutionSuite extends AnalysisTest { partitionColumnNames: Seq[String] = Seq.empty, comment: Option[String] = None, mode: SaveMode = SaveMode.ErrorIfExists, - query: Option[LogicalPlan] = None): CreateTable = { - CreateTable( + query: Option[LogicalPlan] = None): CreateTableV1 = { + CreateTableV1( CatalogTable( identifier = TableIdentifier(table, database), tableType = tableType, @@ -1790,7 +1766,7 @@ class PlanResolutionSuite extends AnalysisTest { allSources.foreach { s => val query = s"CREATE TABLE my_tab STORED AS $s" parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) assert(ct.tableDesc.storage.serde == @@ -1809,14 +1785,14 @@ class PlanResolutionSuite extends AnalysisTest { // No conflicting serdes here, OK parseAndResolve(query1) match { - case parsed1: CreateTable => + case parsed1: CreateTableV1 => assert(parsed1.tableDesc.storage.serde == Some("anything")) assert(parsed1.tableDesc.storage.inputFormat == Some("inputfmt")) assert(parsed1.tableDesc.storage.outputFormat == Some("outputfmt")) } parseAndResolve(query2) match { - case parsed2: CreateTable => + case parsed2: CreateTableV1 => assert(parsed2.tableDesc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) assert(parsed2.tableDesc.storage.inputFormat == Some("inputfmt")) @@ -1832,7 +1808,7 @@ class PlanResolutionSuite extends AnalysisTest { val query = s"CREATE TABLE my_tab ROW FORMAT SERDE 'anything' STORED AS $s" if (supportedSources.contains(s)) { parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) assert(ct.tableDesc.storage.serde == Some("anything")) @@ -1853,7 +1829,7 @@ class PlanResolutionSuite extends AnalysisTest { val query = s"CREATE TABLE my_tab ROW FORMAT DELIMITED FIELDS TERMINATED BY ' ' STORED AS $s" if (supportedSources.contains(s)) { parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) assert(ct.tableDesc.storage.serde == hiveSerde.get.serde @@ -1870,14 +1846,14 @@ class PlanResolutionSuite extends AnalysisTest { test("create hive external table") { val withoutLoc = "CREATE EXTERNAL TABLE my_tab STORED AS parquet" parseAndResolve(withoutLoc) match { - case ct: CreateTable => + case ct: CreateTableV1 => assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri.isEmpty) } val withLoc = "CREATE EXTERNAL TABLE my_tab STORED AS parquet LOCATION '/something/anything'" parseAndResolve(withLoc) match { - case ct: CreateTable => + case ct: CreateTableV1 => assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } @@ -1897,7 +1873,7 @@ class PlanResolutionSuite extends AnalysisTest { test("create hive table - location implies external") { val query = "CREATE TABLE my_tab STORED AS parquet LOCATION '/something/anything'" parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } @@ -2261,14 +2237,6 @@ class PlanResolutionSuite extends AnalysisTest { assert(e2.getMessage.contains("Operation not allowed")) } - test("create table - properties") { - val query = "CREATE 
TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" - parsePlan(query) match { - case state: CreateTableStatement => - assert(state.properties == Map("k1" -> "v1", "k2" -> "v2")) - } - } - test("create table(hive) - everything!") { val query = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala index c9e5d33fea87..7c810671c5b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala @@ -19,52 +19,48 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ShowNamespaces import org.apache.spark.sql.test.SharedSparkSession class ShowNamespacesParserSuite extends AnalysisTest with SharedSparkSession { - test("all namespaces") { - Seq("SHOW NAMESPACES", "SHOW DATABASES").foreach { sqlCmd => + private val keywords = Seq("NAMESPACES", "DATABASES", "SCHEMAS") + + test("show namespaces in the current catalog") { + keywords.foreach { keyword => comparePlans( - parsePlan(sqlCmd), + parsePlan(s"SHOW $keyword"), ShowNamespaces(UnresolvedNamespace(Seq.empty[String]), None)) } } - test("basic pattern") { - Seq( - "SHOW DATABASES LIKE 'defau*'", - "SHOW NAMESPACES LIKE 'defau*'").foreach { sqlCmd => + test("show namespaces with a pattern") { + keywords.foreach { keyword => comparePlans( - parsePlan(sqlCmd), + parsePlan(s"SHOW $keyword LIKE 'defau*'"), + ShowNamespaces(UnresolvedNamespace(Seq.empty[String]), Some("defau*"))) + // LIKE can be omitted. 
+ comparePlans( + parsePlan(s"SHOW $keyword 'defau*'"), ShowNamespaces(UnresolvedNamespace(Seq.empty[String]), Some("defau*"))) - } - } - - test("FROM/IN operator is not allowed by SHOW DATABASES") { - Seq( - "SHOW DATABASES FROM testcat.ns1.ns2", - "SHOW DATABASES IN testcat.ns1.ns2").foreach { sqlCmd => - val errMsg = intercept[ParseException] { - parsePlan(sqlCmd) - }.getMessage - assert(errMsg.contains("FROM/IN operator is not allowed in SHOW DATABASES")) } } test("show namespaces in/from a namespace") { - comparePlans( - parsePlan("SHOW NAMESPACES FROM testcat.ns1.ns2"), - ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) - comparePlans( - parsePlan("SHOW NAMESPACES IN testcat.ns1.ns2"), - ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) + keywords.foreach { keyword => + comparePlans( + parsePlan(s"SHOW $keyword FROM testcat.ns1.ns2"), + ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) + comparePlans( + parsePlan(s"SHOW $keyword IN testcat.ns1.ns2"), + ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) + } } test("namespaces by a pattern from another namespace") { - comparePlans( - parsePlan("SHOW NAMESPACES IN testcat.ns1 LIKE '*pattern*'"), - ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1")), Some("*pattern*"))) + keywords.foreach { keyword => + comparePlans( + parsePlan(s"SHOW $keyword IN testcat.ns1 LIKE '*pattern*'"), + ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1")), Some("*pattern*"))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala index 1b37444b14a0..b3693845c3b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala @@ -41,6 +41,7 @@ trait ShowNamespacesSuiteBase extends QueryTest with DDLCommandTestUtils { } protected def builtinTopNamespaces: Seq[String] = Seq.empty + protected def isCasePreserving: Boolean = true test("default namespace") { withSQLConf(SQLConf.DEFAULT_CATALOG.key -> catalog) { @@ -51,7 +52,7 @@ trait ShowNamespacesSuiteBase extends QueryTest with DDLCommandTestUtils { test("at the top level") { withNamespace(s"$catalog.ns1", s"$catalog.ns2") { - sql(s"CREATE DATABASE $catalog.ns1") + sql(s"CREATE NAMESPACE $catalog.ns1") sql(s"CREATE NAMESPACE $catalog.ns2") runShowNamespacesSql( @@ -64,24 +65,12 @@ trait ShowNamespacesSuiteBase extends QueryTest with DDLCommandTestUtils { withNamespace(s"$catalog.ns1", s"$catalog.ns2") { sql(s"CREATE NAMESPACE $catalog.ns1") sql(s"CREATE NAMESPACE $catalog.ns2") - Seq( - s"SHOW NAMESPACES IN $catalog LIKE 'ns2'", - s"SHOW NAMESPACES IN $catalog 'ns2'", - s"SHOW NAMESPACES FROM $catalog LIKE 'ns2'", - s"SHOW NAMESPACES FROM $catalog 'ns2'").foreach { sqlCmd => - withClue(sqlCmd) { - runShowNamespacesSql(sqlCmd, Seq("ns2")) - } - } + runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'ns2'", Seq("ns2")) } } test("does not match to any namespace") { - Seq( - "SHOW DATABASES LIKE 'non-existentdb'", - "SHOW NAMESPACES 'non-existentdb'").foreach { sqlCmd => - runShowNamespacesSql(sqlCmd, Seq.empty) - } + runShowNamespacesSql("SHOW NAMESPACES LIKE 'non-existentdb'", Seq.empty) } test("show root namespaces with the default catalog") { @@ -134,4 +123,23 @@ trait ShowNamespacesSuiteBase extends QueryTest with 
DDLCommandTestUtils { assert(sql("SHOW NAMESPACES").schema.fieldNames.toSeq == Seq("databaseName")) } } + + test("case sensitivity of the pattern string") { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + withNamespace(s"$catalog.AAA", s"$catalog.bbb") { + sql(s"CREATE NAMESPACE $catalog.AAA") + sql(s"CREATE NAMESPACE $catalog.bbb") + // TODO: The v1 in-memory catalog should be case preserving as well. + val casePreserving = isCasePreserving && (catalogVersion == "V2" || caseSensitive) + val expected = if (casePreserving) "AAA" else "aaa" + runShowNamespacesSql( + s"SHOW NAMESPACES IN $catalog", + Seq(expected, "bbb") ++ builtinTopNamespaces) + runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq(expected)) + runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq(expected)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala index 54c5d2246495..a1b32e42ae2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.command -import org.apache.spark.sql.internal.SQLConf /** * This base suite contains unified tests for the `SHOW NAMESPACES` and `SHOW DATABASES` commands @@ -42,21 +41,4 @@ trait ShowNamespacesSuiteBase extends command.ShowNamespacesSuiteBase { class ShowNamespacesSuite extends ShowNamespacesSuiteBase with CommandSuiteBase { override def commandVersion: String = "V2" // There is only V2 variant of SHOW NAMESPACES. 
- - test("case sensitivity") { - Seq(true, false).foreach { caseSensitive => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - withNamespace(s"$catalog.AAA", s"$catalog.bbb") { - sql(s"CREATE NAMESPACE $catalog.AAA") - sql(s"CREATE NAMESPACE $catalog.bbb") - val expected = if (caseSensitive) "AAA" else "aaa" - runShowNamespacesSql( - s"SHOW NAMESPACES IN $catalog", - Seq(expected, "bbb") ++ builtinTopNamespaces) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq(expected)) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq(expected)) - } - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala index bafb6608c8e6..ded657edc61f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala @@ -53,20 +53,4 @@ class ShowNamespacesSuite extends command.ShowNamespacesSuiteBase with CommandSu }.getMessage assert(errMsg.contains("does not support namespaces")) } - - test("case sensitivity") { - Seq(true, false).foreach { caseSensitive => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - withNamespace(s"$catalog.AAA", s"$catalog.bbb") { - sql(s"CREATE NAMESPACE $catalog.AAA") - sql(s"CREATE NAMESPACE $catalog.bbb") - runShowNamespacesSql( - s"SHOW NAMESPACES IN $catalog", - Seq("AAA", "bbb") ++ builtinTopNamespaces) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq("AAA")) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq("AAA")) - } - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala index bfcef4633990..ceea844abbed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala @@ -56,7 +56,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession { requestedDataColIds: Array[Int], requestedPartitionColIds: Array[Int], resultFields: Array[StructField]): OrcColumnarBatchReader = { - val reader = new OrcColumnarBatchReader(4096) + val reader = new OrcColumnarBatchReader(4096, true) reader.initBatch( orcFileSchema, resultFields, @@ -121,7 +121,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession { val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty) val taskConf = sqlContext.sessionState.newHadoopConf() val orcFileSchema = TypeDescription.fromString(schema.simpleString) - val vectorizedReader = new OrcColumnarBatchReader(4096) + val vectorizedReader = new OrcColumnarBatchReader(4096, true) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 2d6978a81024..8300eea8f297 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -21,6 +21,7 @@ import java.io.File import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.time.{LocalDateTime, ZoneOffset} +import java.util.TimeZone import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -776,7 +777,7 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { } :+ (null, null) withOrcFile(data) { file => - withAllOrcReaders { + withAllNativeOrcReaders { checkAnswer(spark.read.orc(file), data.toDF().collect()) } } @@ -799,7 +800,7 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withTempPath { file => val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) df.write.orc(file.getCanonicalPath) - withAllOrcReaders { + withAllNativeOrcReaders { val msg = intercept[SparkException] { spark.read.schema(providedSchema).orc(file.getCanonicalPath).collect() }.getMessage @@ -825,11 +826,41 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withTempPath { file => val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) df.write.orc(file.getCanonicalPath) - withAllOrcReaders { + withAllNativeOrcReaders { checkAnswer(spark.read.schema(providedSchema).orc(file.getCanonicalPath), answer) } } } + + test("SPARK-37463: read/write Timestamp ntz or ltz to Orc uses UTC timestamp") { + val localTimeZone = TimeZone.getDefault + try { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + + val sqlText = """ + |select + | timestamp_ntz '2021-06-01 00:00:00' ts_ntz1, + | timestamp_ntz '1883-11-16 00:00:00.0' as ts_ntz2, + | timestamp_ntz '2021-03-14 02:15:00.0' as ts_ntz3, + | timestamp_ntz'1996-10-27T09:10:25.088353' as ts_ntz4 + |""".stripMargin + + val df = sql(sqlText) + + df.write.mode("overwrite").orc("ts_ntz_orc") + + val query = "select * from `orc`.`ts_ntz_orc`" + + Seq("America/Los_Angeles", "UTC", "Europe/Amsterdam").foreach { tz => + TimeZone.setDefault(TimeZone.getTimeZone(tz)) + withAllNativeOrcReaders { + checkAnswer(sql(query), df) + } + } + } finally { + TimeZone.setDefault(localTimeZone) + } + } } class OrcV1QuerySuite extends OrcQuerySuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8ffccd9679c5..8953fbb372f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -485,12 +485,10 @@ abstract class OrcSuite } test("SPARK-31238: compatibility with Spark 2.4 in reading dates") { - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - readResourceOrcFile("test-data/before_1582_date_v2_4.snappy.orc"), - Row(java.sql.Date.valueOf("1200-01-01"))) - } + withAllNativeOrcReaders { + checkAnswer( + readResourceOrcFile("test-data/before_1582_date_v2_4.snappy.orc"), + Row(java.sql.Date.valueOf("1200-01-01"))) } } @@ -502,23 +500,19 @@ abstract class OrcSuite .write .orc(path) - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - 
spark.read.orc(path), - Seq(Row(Date.valueOf("1001-01-01")), Row(Date.valueOf("1582-10-15")))) - } + withAllNativeOrcReaders { + checkAnswer( + spark.read.orc(path), + Seq(Row(Date.valueOf("1001-01-01")), Row(Date.valueOf("1582-10-15")))) } } } test("SPARK-31284: compatibility with Spark 2.4 in reading timestamps") { - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - readResourceOrcFile("test-data/before_1582_ts_v2_4.snappy.orc"), - Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"))) - } + withAllNativeOrcReaders { + checkAnswer( + readResourceOrcFile("test-data/before_1582_ts_v2_4.snappy.orc"), + Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"))) } } @@ -530,14 +524,12 @@ abstract class OrcSuite .write .orc(path) - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - spark.read.orc(path), - Seq( - Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456")), - Row(java.sql.Timestamp.valueOf("1582-10-15 11:12:13.654321")))) - } + withAllNativeOrcReaders { + checkAnswer( + spark.read.orc(path), + Seq( + Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456")), + Row(java.sql.Timestamp.valueOf("1582-10-15 11:12:13.654321")))) } } } @@ -809,11 +801,12 @@ abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession { } } - Seq(true, false).foreach { vecReaderEnabled => + withAllNativeOrcReaders { Seq(true, false).foreach { vecReaderNestedColEnabled => + val vecReaderEnabled = SQLConf.get.orcVectorizedReaderEnabled test("SPARK-36931: Support reading and writing ANSI intervals (" + - s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " + - s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") { + s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " + + s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") { withSQLConf( SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index cd87374e8574..96932de3275b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -143,7 +143,7 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor spark.read.orc(file.getAbsolutePath) } - def withAllOrcReaders(code: => Unit): Unit = { + def withAllNativeOrcReaders(code: => Unit): Unit = { // test the row-based reader withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false")(code) // test the vectorized reader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index 2317a4d00e06..79b8c9e2c571 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -29,14 +29,15 @@ import org.apache.spark.sql.test.SharedSparkSession class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSession { import testImplicits._ - val 
ROW = ((1).toByte, 2, 3L, "abc", Period.of(1, 1, 0), Duration.ofMillis(100)) + val ROW = ((1).toByte, 2, 3L, "abc", Period.of(1, 1, 0), Duration.ofMillis(100), true) val NULL_ROW = ( null.asInstanceOf[java.lang.Byte], null.asInstanceOf[Integer], null.asInstanceOf[java.lang.Long], null.asInstanceOf[String], null.asInstanceOf[Period], - null.asInstanceOf[Duration]) + null.asInstanceOf[Duration], + null.asInstanceOf[java.lang.Boolean]) test("All Types Dictionary") { (1 :: 1000 :: Nil).foreach { n => { @@ -59,6 +60,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess assert(batch.column(3).getUTF8String(i).toString == "abc") assert(batch.column(4).getInt(i) == 13) assert(batch.column(5).getLong(i) == 100000) + assert(batch.column(6).getBoolean(i) == true) i += 1 } reader.close() @@ -88,6 +90,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess assert(batch.column(3).isNullAt(i)) assert(batch.column(4).isNullAt(i)) assert(batch.column(5).isNullAt(i)) + assert(batch.column(6).isNullAt(i)) i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f12e5af9d43f..0966319f53fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -145,7 +145,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession val numRecords = 100 val writer = createParquetWriter(schema, tablePath, dictionaryEnabled = dictEnabled) - (0 until numRecords).map { i => + (0 until numRecords).foreach { i => val record = new SimpleGroup(schema) for (group <- Seq(0, 2, 4)) { record.add(group, 1000L) // millis diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala index 5f3d750e8f27..090c149886a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala @@ -58,7 +58,7 @@ class SQLEventFilterBuilderSuite extends SparkFunSuite { // Start SQL Execution listener.onOtherEvent(SparkListenerSQLExecutionStart(1, "desc1", "details1", "plan", - new SparkPlanInfo("node", "str", Seq.empty, Map.empty, Seq.empty), time)) + new SparkPlanInfo("node", "str", Seq.empty, Map.empty, Seq.empty), time, Map.empty)) time += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala index 46fdaba413c6..724df8ebe8bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala @@ -42,7 +42,7 @@ class SQLLiveEntitiesEventFilterSuite extends SparkFunSuite { // Verifying with finished SQL execution 1 assert(Some(false) === acceptFn(SparkListenerSQLExecutionStart(1, "description1", "details1", - "plan", null, 0))) + "plan", null, 0, Map.empty))) assert(Some(false) === acceptFn(SparkListenerSQLExecutionEnd(1, 0))) assert(Some(false) 
=== acceptFn(SparkListenerSQLAdaptiveExecutionUpdate(1, "plan", null))) assert(Some(false) === acceptFn(SparkListenerDriverAccumUpdates(1, Seq.empty))) @@ -89,7 +89,7 @@ class SQLLiveEntitiesEventFilterSuite extends SparkFunSuite { // Verifying with live SQL execution 2 assert(Some(true) === acceptFn(SparkListenerSQLExecutionStart(2, "description2", "details2", - "plan", null, 0))) + "plan", null, 0, Map.empty))) assert(Some(true) === acceptFn(SparkListenerSQLExecutionEnd(2, 0))) assert(Some(true) === acceptFn(SparkListenerSQLAdaptiveExecutionUpdate(2, "plan", null))) assert(Some(true) === acceptFn(SparkListenerDriverAccumUpdates(2, Seq.empty))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala index 24b8a973ade3..1f5cbb0e19ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala @@ -112,7 +112,8 @@ class AllExecutionsPageSuite extends SharedSparkSession with BeforeAndAfter { "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala index 533d98da240f..aa3988ae37e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala @@ -79,7 +79,8 @@ object MetricsAggregationBenchmark extends BenchmarkBase { getClass().getName(), getClass().getName(), planInfo, - System.currentTimeMillis()) + System.currentTimeMillis(), + Map.empty) val executionEnd = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index e776a4ac23f7..61230641ded8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -198,7 +198,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, @@ -345,7 +346,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLExecutionStart(_, _, _, planDescription, _, _) => + case SparkListenerSQLExecutionStart(_, _, _, planDescription, _, _, _) => assert(expected.forall(planDescription.contains)) checkDone = true case _ => // ignore other events @@ -387,7 +388,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, 
SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -416,7 +418,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -456,7 +459,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -485,7 +489,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( @@ -515,7 +520,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) var stageId = 0 def twoStageJob(jobId: Int): Unit = { @@ -654,7 +660,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - time)) + time, + Map.empty)) time += 1 listener.onOtherEvent(SparkListenerSQLExecutionStart( 2, @@ -662,7 +669,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - time)) + time, + Map.empty)) // Stop execution 2 before execution 1 time += 1 @@ -678,7 +686,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - time)) + time, + Map.empty)) assert(statusStore.executionsCount === 2) assert(statusStore.execution(2) === None) } @@ -713,7 +722,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, oldPlan, - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0477b41942d4..738f2281c9a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -25,9 +25,11 @@ import java.util.NoSuchElementException import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.apache.arrow.vector.IntVector +import org.apache.parquet.bytes.ByteBufferInputStream 
import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode @@ -36,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.parquet.VectorizedPlainValuesReader import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnarBatchRow, ColumnVector} @@ -130,6 +133,97 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + testVector("Boolean APIs", 1024, BooleanType) { + column => + val reference = mutable.ArrayBuffer.empty[Boolean] + + var values = Array(true, false, true, false, false) + var bits = values.foldRight(0)((b, i) => i << 1 | (if (b) 1 else 0)).toByte + column.appendBooleans(2, bits, 0) + reference ++= values.slice(0, 2) + + column.appendBooleans(3, bits, 2) + reference ++= values.slice(2, 5) + + column.appendBooleans(6, true) + reference ++= Array.fill(6)(true) + + column.appendBoolean(false) + reference += false + + var idx = column.elementsAppended + + values = Array(true, true, false, true, false, true, false, true) + bits = values.foldRight(0)((b, i) => i << 1 | (if (b) 1 else 0)).toByte + column.putBooleans(idx, 2, bits, 0) + reference ++= values.slice(0, 2) + idx += 2 + + column.putBooleans(idx, 3, bits, 2) + reference ++= values.slice(2, 5) + idx += 3 + + column.putBooleans(idx, bits) + reference ++= values + idx += 8 + + column.putBoolean(idx, false) + reference += false + idx += 1 + + column.putBooleans(idx, 3, true) + reference ++= Array.fill(3)(true) + idx += 3 + + implicit def intToByte(i: Int): Byte = i.toByte + val buf = ByteBuffer.wrap(Array(0x33, 0x5A, 0xA5, 0xCC, 0x0F, 0xF0, 0xEE, 0x77, 0x88)) + val reader = new VectorizedPlainValuesReader() + reader.initFromPage(0, ByteBufferInputStream.wrap(buf)) + + reader.skipBooleans(1) // bit index 0 + + column.putBoolean(idx, reader.readBoolean) // bit index 1 + reference += true + idx += 1 + + column.putBoolean(idx, reader.readBoolean) // bit index 2 + reference += false + idx += 1 + + reader.skipBooleans(5) // bit index [3, 7] + + column.putBoolean(idx, reader.readBoolean) // bit index 8 + reference += false + idx += 1 + + reader.skipBooleans(8) // bit index [9, 16] + reader.skipBooleans(0) // no-op + + column.putBoolean(idx, reader.readBoolean) // bit index 17 + reference += false + idx += 1 + + reader.skipBooleans(16) // bit index [18, 33] + + reader.readBooleans(4, column, idx) // bit index [34, 37] + reference ++= Array(true, true, false, false) + idx += 4 + + reader.readBooleans(11, column, idx) // bit index [38, 48] + reference ++= Array(false, false, false, false, false, false, true, true, true, true, false) + idx += 11 + + reader.skipBooleans(7) // bit index [49, 55] + + reader.readBooleans(9, column, idx) // bit index [56, 64] + reference ++= Array(true, true, true, false, true, true, true, false, false) + idx += 9 + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getBoolean(v._2), "VectorType=" + column.getClass.getSimpleName) + } + } + testVector("Byte APIs", 1024, ByteType) { column => val reference = mutable.ArrayBuffer.empty[Byte] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 21a0b24cb425..54bed5c966d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -1229,7 +1229,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi q.processAllAvailable() q } finally { - spark.streams.active.map(_.stop()) + spark.streams.active.foreach(_.stop()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala index dbc33c47fed5..baa04ada8b5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala @@ -85,6 +85,7 @@ object SqlResourceSuite { description = DESCRIPTION, details = "", physicalPlanDescription = PLAN_DESCRIPTION, + Map.empty, metrics = metrics, submissionTime = 1586768888233L, completionTime = Some(new Date(1586768888999L)), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 179b424fefb2..5fccce2678f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions} import org.apache.spark.sql.hive.client.HiveClient @@ -436,8 +436,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val properties = new mutable.HashMap[String, String] properties.put(CREATED_SPARK_VERSION, table.createVersion) + // This is for backward compatibility, so that Spark 2 can read tables with char/varchar columns + // created by Spark 3.1. On the read side, we restore the table schema from its properties, so we + // need to replace `varchar(n)` and `char(n)` with `string`, as Spark 2 does not have a type + // mapping for them in `DataType.nameToType`. + // See `restoreHiveSerdeTable` for an example. + val newSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) CatalogTable.splitLargeTableProp( - DATASOURCE_SCHEMA, schema.json, properties.put, conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)) + DATASOURCE_SCHEMA, + newSchema.json, + properties.put, + conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)) if (partitionColumns.nonEmpty) { properties.put(DATASOURCE_SCHEMA_NUMPARTCOLS, partitionColumns.length.toString) @@ -742,8 +751,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat case None if table.tableType == VIEW => // If this is a view created by Spark 2.2 or higher versions, we should restore its schema // from table properties.
- CatalogTable.readLargeTableProp(table.properties, DATASOURCE_SCHEMA).foreach { schemaJson => - table = table.copy(schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]) + getSchemaFromTableProperties(table.properties).foreach { schemaFromTableProps => + table = table.copy(schema = schemaFromTableProps) } // No provider in table properties, which means this is a Hive serde table. @@ -793,9 +802,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // If this is a Hive serde table created by Spark 2.1 or higher versions, we should restore its // schema from table properties. - val schemaJson = CatalogTable.readLargeTableProp(table.properties, DATASOURCE_SCHEMA) - if (schemaJson.isDefined) { - val schemaFromTableProps = DataType.fromJson(schemaJson.get).asInstanceOf[StructType] + val maybeSchemaFromTableProps = getSchemaFromTableProperties(table.properties) + if (maybeSchemaFromTableProps.isDefined) { + val schemaFromTableProps = maybeSchemaFromTableProps.get val partColumnNames = getPartitionColumnsFromTableProperties(table) val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) @@ -821,6 +830,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + private def getSchemaFromTableProperties( + tableProperties: Map[String, String]): Option[StructType] = { + CatalogTable.readLargeTableProp(tableProperties, DATASOURCE_SCHEMA).map { schemaJson => + val parsed = DataType.fromJson(schemaJson).asInstanceOf[StructType] + CharVarcharUtils.getRawSchema(parsed) + } + } + private def restoreDataSourceTable(table: CatalogTable, provider: String): CatalogTable = { // Internally we store the table location in storage properties with key "path" for data // source tables. Here we set the table location to `locationUri` field and filter out the @@ -835,8 +852,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat storageWithLocation.properties.filterKeys(!HIVE_GENERATED_STORAGE_PROPERTIES(_)).toMap) val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) - val schemaFromTableProps = CatalogTable.readLargeTableProp(table.properties, DATASOURCE_SCHEMA) - .map(json => DataType.fromJson(json).asInstanceOf[StructType]).getOrElse(new StructType()) + val schemaFromTableProps = + getSchemaFromTableProperties(table.properties).getOrElse(new StructType()) val partColumnNames = getPartitionColumnsFromTableProperties(table) val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala index 015001fa4f78..2f7303c42c98 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution.command import org.apache.spark.sql.execution.command.v1 -import org.apache.spark.sql.internal.SQLConf /** * The class contains tests for the `SHOW NAMESPACES` and `SHOW DATABASES` commands to check @@ -26,22 +25,8 @@ import org.apache.spark.sql.internal.SQLConf */ class ShowNamespacesSuite extends v1.ShowNamespacesSuiteBase with CommandSuiteBase { override def commandVersion: String = "V2" // There is only V2 variant of SHOW NAMESPACES. 
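A minimal sketch of the char/varchar round trip that the HiveExternalCatalog comment above describes, assuming the CharVarcharUtils helpers used in that hunk record the raw type in field metadata; the varchar column below is a hypothetical illustration, not code from this patch:

    import org.apache.spark.sql.catalyst.util.CharVarcharUtils
    import org.apache.spark.sql.types.{StructType, VarcharType}

    // Hypothetical table schema containing a varchar column.
    val rawSchema = new StructType().add("name", VarcharType(10))
    // Write side: char/varchar are replaced with string before the schema JSON is stored in
    // table properties, so Spark 2 (which has no mapping for these types) can still parse it.
    val storedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(rawSchema)
    // Read side: getSchemaFromTableProperties calls getRawSchema to recover the original
    // char/varchar types from the metadata kept alongside the string field.
    val restoredSchema = CharVarcharUtils.getRawSchema(storedSchema)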
- - test("case sensitivity") { - Seq(true, false).foreach { caseSensitive => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - withNamespace(s"$catalog.AAA", s"$catalog.bbb") { - sql(s"CREATE NAMESPACE $catalog.AAA") - sql(s"CREATE NAMESPACE $catalog.bbb") - runShowNamespacesSql( - s"SHOW NAMESPACES IN $catalog", - Seq("aaa", "bbb") ++ builtinTopNamespaces) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq("aaa")) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq("aaa")) - } - } - } - } + // Hive Catalog is not case preserving and always lower-cases the namespace name when storing it. + override def isCasePreserving: Boolean = false test("hive client calls") { withNamespace(s"$catalog.ns1", s"$catalog.ns2") {