diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f121b62a53d2..f301606933a9 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,4 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C spark-submit2.cmd %* +cmd /V /E /C "%~dp0spark-submit2.cmd" %* diff --git a/tags/README.md b/common/tags/README.md similarity index 100% rename from tags/README.md rename to common/tags/README.md diff --git a/tags/pom.xml b/common/tags/pom.xml similarity index 97% rename from tags/pom.xml rename to common/tags/pom.xml index 3e8e6f618287..8e702b4fefe8 100644 --- a/tags/pom.xml +++ b/common/tags/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/main/java/org/apache/spark/tags/DockerTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/unsafe/pom.xml b/common/unsafe/pom.xml similarity index 98% rename from unsafe/pom.xml rename to common/unsafe/pom.xml index 75fea556eeae..5250014739da 100644 --- a/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -23,7 +23,7 @@ org.apache.spark spark-parent_2.11 2.0.0-SNAPSHOT - ../pom.xml + ../../pom.xml org.apache.spark diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/Platform.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java similarity index 100% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java diff --git 
a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java similarity index 100% rename from unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala similarity index 100% rename from unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala rename to common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 6195916195e3..167c8020850d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -149,7 +149,8 @@ $(document).ready(function() { {name: 'seventh'}, {name: 'eighth'}, ], - "autoWidth": false + "autoWidth": false, + "order": [[ 0, "desc" ]] }; var rowGroupConf = { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a1fa266e183e..0e8b735b923b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -244,7 +244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def eventLogDir: Option[URI] = _eventLogDir private[spark] def eventLogCodec: Option[String] = _eventLogCodec - def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + def isLocal: Boolean = Utils.isLocalMaster(_conf) /** * @return true if context is stopped or in the midst of stopping. 
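Note on the SparkContext hunk above: `isLocal` now delegates to the new `Utils.isLocalMaster(_conf)` helper introduced in the core/src/main/scala/org/apache/spark/util/Utils.scala hunk further down, so the "local master" check lives in one place. A minimal standalone sketch of that check (plain strings instead of a SparkConf, for illustration only):

```scala
// Sketch of the predicate Utils.isLocalMaster applies to the value of "spark.master".
object IsLocalMasterSketch {
  def isLocalMaster(master: String): Boolean =
    master == "local" || master.startsWith("local[")

  def main(args: Array[String]): Unit = {
    // "local" and "local[N]"/"local[*]" count as local mode; cluster masters do not.
    Seq("local", "local[4]", "local[*]", "yarn-client", "spark://host:7077")
      .foreach(m => println(s"$m -> ${isLocalMaster(m)}"))
  }
}
```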
@@ -526,10 +526,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) - if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") - } - _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 915ef81b4eae..175756b80b6b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -255,6 +255,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } + + if (proxyUser != null && principal != null) { + SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + } } private def validateKillArguments(): Unit = { @@ -517,6 +521,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | | --proxy-user NAME User to impersonate when submitting the application. + | This argument does not work with --principal / --keytab. | | --help, -h Show this help message and exit | --verbose, -v Print additional debug output diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a602fcac68a6..a959f200d4cc 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -114,6 +114,19 @@ private[spark] class Executor( private val heartbeatReceiverRef = RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + /** + * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES` + * times, it should kill itself. The default value is 60, which means the executor will keep + * retrying for about 10 minutes, since the heartbeat interval is 10s. + */ + private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60) + + /** + * Counts the number of consecutive heartbeat failures. It should only be accessed in the + * heartbeat thread. Each successful heartbeat resets it to 0.
+ */ + private var heartbeatFailures = 0 + startDriverHeartbeater() def launchTask( @@ -461,8 +474,16 @@ private[spark] class Executor( logInfo("Told to re-register on heartbeat") env.blockManager.reregister() } + heartbeatFailures = 0 } catch { - case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e) + case NonFatal(e) => + logWarning("Issue communicating with driver in heartbeater", e) + heartbeatFailures += 1 + if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) { + logError(s"Exit as unable to send heartbeats to driver " + + s"more than $HEARTBEAT_MAX_FAILURES times") + System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ea36fb60bd54..99858f785600 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -39,6 +39,12 @@ object ExecutorExitCode { /** ExternalBlockStore failed to create a local temporary directory after many attempts. */ val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55 + /** + * Executor is unable to send heartbeats to the driver more than + * "spark.executor.heartbeat.maxFailures" times. + */ + val HEARTBEAT_FAILURE = 56 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -51,6 +57,8 @@ object ExecutorExitCode { // TODO: replace external block store with concrete implementation name case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR => "ExternalBlockStore failed to create a local temporary directory." + case HEARTBEAT_FAILURE => + "Unable to send heartbeats to driver." case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e0c9bf02a1a2..6103a10ccc50 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2195,6 +2195,16 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + + /** + * + * @return whether it is local mode + */ + def isLocalMaster(conf: SparkConf): Boolean = { + val master = conf.get("spark.master", "") + master == "local" || master.startsWith("local[") + } + /** * Return whether dynamic allocation is enabled in the given conf * Dynamic allocation and explicitly setting the number of executors are inherently @@ -2202,8 +2212,13 @@ private[spark] object Utils extends Logging { * the latter should override the former (SPARK-9092). 
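Note on the heartbeat change in Executor.scala above: the executor now counts consecutive failed heartbeats and calls System.exit(ExecutorExitCode.HEARTBEAT_FAILURE) (exit code 56) once the count reaches `spark.executor.heartbeat.maxFailures` (default 60, roughly 10 minutes at the default 10s heartbeat interval). A hedged tuning sketch; the configuration key and default come from the hunk above, while the value 30 is only an example:

```scala
import org.apache.spark.SparkConf

// Example only: make an executor give up after ~5 minutes of failed heartbeats
// (30 failures x 10s default interval) instead of the default 60.
val conf = new SparkConf()
  .setAppName("heartbeat-tuning-sketch")
  .set("spark.executor.heartbeat.maxFailures", "30")
```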
*/ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - conf.getBoolean("spark.dynamicAllocation.enabled", false) && - conf.getInt("spark.executor.instances", 0) == 0 + val numExecutor = conf.getInt("spark.executor.instances", 0) + val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) + if (numExecutor != 0 && dynamicAllocationEnabled) { + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + numExecutor == 0 && dynamicAllocationEnabled && + (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) } def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 7c6778b06546..412c0ac9d9be 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -722,6 +722,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("isDynamicAllocationEnabled") { val conf = new SparkConf() + conf.set("spark.master", "yarn-client") assert(Utils.isDynamicAllocationEnabled(conf) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "false")) === false) @@ -731,6 +732,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "1")) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.executor.instances", "0")) === true) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false) + assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true"))) } test("encodeFileNameToURIRawPath") { diff --git a/dev/run-tests.py b/dev/run-tests.py index 6febbf108900..b65d1a309cb4 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -488,7 +488,7 @@ def main(): if which("R"): run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) else: - print("Can't install SparkR as R is was not found in PATH") + print("Cannot install SparkR as R was not found in PATH") if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4e04672ad39e..e4f2edaf9511 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -477,7 +477,7 @@ def __hash__(self): ], sbt_test_goals=[ "yarn/test", - "common/network-yarn/test", + "network-yarn/test", ], test_tags=[ "org.apache.spark.tags.ExtendedYarnTest" diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java new file mode 100644 index 000000000000..e124c1cf1855 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +// $example on$ +import org.apache.spark.ml.clustering.BisectingKMeans; +import org.apache.spark.ml.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + + +/** + * An example demonstrating a bisecting k-means clustering. + */ +public class JavaBisectingKMeansExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)), + RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)), + RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)), + RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)), + RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)), + RowFactory.create(Vectors.dense(18.9, 20.0, 19.7)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + BisectingKMeans bkm = new BisectingKMeans().setK(2); + BisectingKMeansModel model = bkm.fit(dataset); + + System.out.println("Compute Cost: " + model.computeCost(dataset)); + + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java index 0001500f4fa5..c600094947d5 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -33,7 +33,7 @@ // $example off$ /** - * Java example for graph clustering using power iteration clustering (PIC). + * Java example for bisecting k-means clustering. 
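For readers of the Scala API, a rough Scala counterpart of the new spark.ml Java example above (same toy data, k = 2); this is a sketch, not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

object BisectingKMeansScalaSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("BisectingKMeansScalaSketch"))
    val sqlContext = new SQLContext(sc)

    // Same six points as the Java example: two well-separated groups.
    val data = Seq(
      Vectors.dense(0.1, 0.1, 0.1), Vectors.dense(0.3, 0.3, 0.25), Vectors.dense(0.1, 0.1, -0.1),
      Vectors.dense(20.3, 20.1, 19.9), Vectors.dense(20.2, 20.1, 19.7), Vectors.dense(18.9, 20.0, 19.7))
    val dataset = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

    val model = new BisectingKMeans().setK(2).fit(dataset)
    println("Compute Cost: " + model.computeCost(dataset))
    model.clusterCenters.zipWithIndex.foreach { case (center, i) =>
      println(s"Cluster Center $i: $center")
    }
    sc.stop()
  }
}
```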
*/ public class JavaBisectingKMeansExample { public static void main(String[] args) { @@ -54,9 +54,7 @@ public static void main(String[] args) { BisectingKMeansModel model = bkm.run(data); System.out.println("Compute Cost: " + model.computeCost(data)); - for (Vector center: model.clusterCenters()) { - System.out.println(""); - } + Vector[] clusterCenters = model.clusterCenters(); for (int i = 0; i < clusterCenters.length; i++) { Vector clusterCenter = clusterCenters[i]; diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 0ec6214fdef1..6442b2a4e294 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -62,7 +62,7 @@ object TwitterAlgebirdHLL { var userSet: Set[Long] = Set() val approxUsers = users.mapPartitions(ids => { - ids.map(id => hll(id)) + ids.map(id => hll.create(id)) }).reduce(_ + _) val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 61b364213181..55b751065664 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -156,6 +156,12 @@ private[ml] class WeightedLeastSquares( private[ml] object WeightedLeastSquares { + /** + * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]] + * only supports the number of features is no more than 4096. + */ + val MAX_NUM_FEATURES: Int = 4096 + /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ @@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares { private var aaSum: DenseVector = _ private def init(k: Int): Unit = { - require(k <= 4096, "In order to take the normal equation approach efficiently, " + - s"we set the max number of features to 4096 but got $k.") + require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " + + s"we set the max number of features to $MAX_NUM_FEATURES but got $k.") this.k = k triK = k * (k + 1) / 2 count = 0L diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala new file mode 100644 index 000000000000..a850dfee0a45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import breeze.stats.distributions.{Gaussian => GD} + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.optim._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ + +/** + * Params for Generalized Linear Regression. + */ +private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams + with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol + with HasSolver with Logging { + + /** + * Param for the name of family which is a description of the error distribution + * to be used in the model. + * Supported options: "gaussian", "binomial", "poisson" and "gamma". + * Default is "gaussian". + * @group param + */ + @Since("2.0.0") + final val family: Param[String] = new Param(this, "family", + "The name of family which is a description of the error distribution to be used in the " + + "model. Supported options: gaussian(default), binomial, poisson and gamma.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getFamily: String = $(family) + + /** + * Param for the name of link function which provides the relationship + * between the linear predictor and the mean of the distribution function. + * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". + * @group param + */ + @Since("2.0.0") + final val link: Param[String] = new Param(this, "link", "The name of link function " + + "which provides the relationship between the linear predictor and the mean of the " + + "distribution function. Supported options: identity, log, inverse, logit, probit, " + + "cloglog and sqrt.", + ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + + /** @group getParam */ + @Since("2.0.0") + def getLink: String = $(link) + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + override def validateParams(): Unit = { + if ($(solver) == "irls") { + setDefault(maxIter -> 25) + } + if (isDefined(link)) { + require(supportedFamilyAndLinkPairs.contains( + Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + + s"with ${$(family)} family does not support ${$(link)} link function.") + } + } +} + +/** + * :: Experimental :: + * + * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]]) + * specified by giving a symbolic description of the linear predictor (link function) and + * a description of the error distribution (family). + * It supports "gaussian", "binomial", "poisson" and "gamma" as family. + * Valid link functions for each family is listed below. The first link function of each family + * is the default one. 
+ * - "gaussian" -> "identity", "log", "inverse" + * - "binomial" -> "logit", "probit", "cloglog" + * - "poisson" -> "log", "identity", "sqrt" + * - "gamma" -> "inverse", "identity", "log" + */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) + extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase with Logging { + + import GeneralizedLinearRegression._ + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("glm")) + + /** + * Sets the value of param [[family]]. + * Default is "gaussian". + * @group setParam + */ + @Since("2.0.0") + def setFamily(value: String): this.type = set(family, value) + setDefault(family -> Gaussian.name) + + /** + * Sets the value of param [[link]]. + * @group setParam + */ + @Since("2.0.0") + def setLink(value: String): this.type = set(link, value) + + /** + * Sets if we should fit the intercept. + * Default is true. + * @group setParam + */ + @Since("2.0.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + + /** + * Sets the maximum number of iterations. + * Default is 25 if the solver algorithm is "irls". + * @group setParam + */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Sets the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Sets the regularization parameter. + * Default is 0.0. + * @group setParam + */ + @Since("2.0.0") + def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is empty, so all instances have weight one. + * @group setParam + */ + @Since("2.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + + /** + * Sets the solver algorithm used for optimization. + * Currently only support "irls" which is also the default solver. + * @group setParam + */ + @Since("2.0.0") + def setSolver(value: String): this.type = set(solver, value) + setDefault(solver -> "irls") + + override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + val familyObj = Family.fromName($(family)) + val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd + .map { case Row(features: Vector) => + features.size + }.first() + if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { + val msg = "Currently, GeneralizedLinearRegression only supports number of features" + + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." + throw new SparkException(msg) + } + + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd + .map { case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + + if (familyObj == Gaussian && linkObj == Identity) { + // TODO: Make standardizeFeatures and standardizeLabel configurable. 
+ val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), + standardizeFeatures = true, standardizeLabel = true) + val wlsModel = optimizer.fit(instances) + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) + .setParent(this)) + return model + } + + // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, + $(fitIntercept), $(regParam), $(maxIter), $(tol)) + val irlsModel = optimizer.fit(instances) + + val model = copyValues( + new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) + .setParent(this)) + model + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra) +} + +@Since("2.0.0") +private[ml] object GeneralizedLinearRegression { + + /** Set of family and link pairs that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyAndLinkPairs = Set( + Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, + Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, + Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, + Gamma -> Inverse, Gamma -> Identity, Gamma -> Log + ) + + /** Set of family names that GeneralizedLinearRegression supports. */ + lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + + /** Set of link names that GeneralizedLinearRegression supports. */ + lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + + val epsilon: Double = 1E-16 + + /** + * Wrapper of family and link combination used in the model. + */ + private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + + /** Linear predictor based on given mu. */ + def predict(mu: Double): Double = link.link(family.project(mu)) + + /** Fitted value based on linear predictor eta. */ + def fitted(eta: Double): Double = family.project(link.unlink(eta)) + + /** + * Get the initial guess model for [[IterativelyReweightedLeastSquares]]. + */ + def initialize( + instances: RDD[Instance], + fitIntercept: Boolean, + regParam: Double): WeightedLeastSquaresModel = { + val newInstances = instances.map { instance => + val mu = family.initialize(instance.label, instance.weight) + val eta = predict(mu) + Instance(eta, instance.weight, instance.features) + } + // TODO: Make standardizeFeatures and standardizeLabel configurable. + val initialModel = new WeightedLeastSquares(fitIntercept, regParam, + standardizeFeatures = true, standardizeLabel = true) + .fit(newInstances) + initialModel + } + + /** + * The reweight function used to update offsets and weights + * at each iteration of [[IterativelyReweightedLeastSquares]]. + */ + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: Instance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + val mu = fitted(eta) + val offset = eta + (instance.label - mu) * link.deriv(mu) + val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (offset, weight) + } + } + } + + /** + * A description of the error distribution to be used in the model. + * @param name the name of the family. + */ + private[ml] abstract class Family(val name: String) extends Serializable { + + /** The default link instance of this family. 
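For reference, `reweightFunc` above is the standard IRLS update written in code: given the current model, each instance's working label (offset) and working weight are

$$ z_i = \eta_i + (y_i - \mu_i)\,g'(\mu_i), \qquad w_i^{\text{new}} = \frac{w_i}{[g'(\mu_i)]^2\,\mathrm{Var}(\mu_i)} $$

where $\eta_i$ is the linear predictor for instance $i$, $\mu_i = g^{-1}(\eta_i)$ the fitted mean, $g$ the link function and $\mathrm{Var}(\cdot)$ the family's variance function; this is just the `reweightFunc` code in equation form.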
*/ + val defaultLink: Link + + /** Initialize the starting value for mu. */ + def initialize(y: Double, weight: Double): Double + + /** The variance of the endogenous variable's mean, given the value mu. */ + def variance(mu: Double): Double + + /** Trim the fitted value so that it will be in valid range. */ + def project(mu: Double): Double = mu + } + + private[ml] object Family { + + /** + * Gets the [[Family]] object from its name. + * @param name family name: "gaussian", "binomial", "poisson" or "gamma". + */ + def fromName(name: String): Family = { + name match { + case Gaussian.name => Gaussian + case Binomial.name => Binomial + case Poisson.name => Poisson + case Gamma.name => Gamma + } + } + } + + /** + * Gaussian exponential family distribution. + * The default link for the Gaussian family is the identity link. + */ + private[ml] object Gaussian extends Family("gaussian") { + + val defaultLink: Link = Identity + + override def initialize(y: Double, weight: Double): Double = y + + def variance(mu: Double): Double = 1.0 + + override def project(mu: Double): Double = { + if (mu.isNegInfinity) { + Double.MinValue + } else if (mu.isPosInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Binomial exponential family distribution. + * The default link for the Binomial family is the logit link. + */ + private[ml] object Binomial extends Family("binomial") { + + val defaultLink: Link = Logit + + override def initialize(y: Double, weight: Double): Double = { + val mu = (weight * y + 0.5) / (weight + 1.0) + require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family" + + s"should be in range (0, 1), but got $mu") + mu + } + + override def variance(mu: Double): Double = mu * (1.0 - mu) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu > 1.0 - epsilon) { + 1.0 - epsilon + } else { + mu + } + } + } + + /** + * Poisson exponential family distribution. + * The default link for the Poisson family is the log link. + */ + private[ml] object Poisson extends Family("poisson") { + + val defaultLink: Link = Log + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Poisson family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = mu + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * Gamma exponential family distribution. + * The default link for the Gamma family is the inverse link. + */ + private[ml] object Gamma extends Family("gamma") { + + val defaultLink: Link = Inverse + + override def initialize(y: Double, weight: Double): Double = { + require(y > 0.0, "The response variable of Gamma family " + + s"should be positive, but got $y") + y + } + + override def variance(mu: Double): Double = math.pow(mu, 2.0) + + override def project(mu: Double): Double = { + if (mu < epsilon) { + epsilon + } else if (mu.isInfinity) { + Double.MaxValue + } else { + mu + } + } + } + + /** + * A description of the link function to be used in the model. + * The link function provides the relationship between the linear predictor + * and the mean of the distribution function. + * @param name the name of link function. + */ + private[ml] abstract class Link(val name: String) extends Serializable { + + /** The link function. */ + def link(mu: Double): Double + + /** Derivative of the link function. 
*/ + def deriv(mu: Double): Double + + /** The inverse link function. */ + def unlink(eta: Double): Double + } + + private[ml] object Link { + + /** + * Gets the [[Link]] object from its name. + * @param name link name: "identity", "logit", "log", + * "inverse", "probit", "cloglog" or "sqrt". + */ + def fromName(name: String): Link = { + name match { + case Identity.name => Identity + case Logit.name => Logit + case Log.name => Log + case Inverse.name => Inverse + case Probit.name => Probit + case CLogLog.name => CLogLog + case Sqrt.name => Sqrt + } + } + } + + private[ml] object Identity extends Link("identity") { + + override def link(mu: Double): Double = mu + + override def deriv(mu: Double): Double = 1.0 + + override def unlink(eta: Double): Double = eta + } + + private[ml] object Logit extends Link("logit") { + + override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) + + override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) + } + + private[ml] object Log extends Link("log") { + + override def link(mu: Double): Double = math.log(mu) + + override def deriv(mu: Double): Double = 1.0 / mu + + override def unlink(eta: Double): Double = math.exp(eta) + } + + private[ml] object Inverse extends Link("inverse") { + + override def link(mu: Double): Double = 1.0 / mu + + override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0) + + override def unlink(eta: Double): Double = 1.0 / eta + } + + private[ml] object Probit extends Link("probit") { + + override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu) + + override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu)) + + override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta) + } + + private[ml] object CLogLog extends Link("cloglog") { + + override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) + + override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu)) + + override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) + } + + private[ml] object Sqrt extends Link("sqrt") { + + override def link(mu: Double): Double = math.sqrt(mu) + + override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu)) + + override def unlink(eta: Double): Double = math.pow(eta, 2.0) + } +} + +/** + * :: Experimental :: + * Model produced by [[GeneralizedLinearRegression]]. 
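With the estimator and its families/links defined above, end-to-end usage looks roughly like the following sketch; `dataset` is an assumed DataFrame with "label" and "features" columns (for example built as in the test suite further below):

```scala
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")  // one of: gaussian, binomial, poisson, gamma
  .setLink("log")        // optional; if unset, the family's default link is used
  .setMaxIter(25)
  .setRegParam(0.0)

val model = glr.fit(dataset)  // `dataset` assumed: DataFrame("label", "features")
println(s"coefficients = ${model.coefficients}, intercept = ${model.intercept}")
```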
+ */ +@Experimental +@Since("2.0.0") +class GeneralizedLinearRegressionModel private[ml] ( + @Since("2.0.0") override val uid: String, + @Since("2.0.0") val coefficients: Vector, + @Since("2.0.0") val intercept: Double) + extends RegressionModel[Vector, GeneralizedLinearRegressionModel] + with GeneralizedLinearRegressionBase { + + import GeneralizedLinearRegression._ + + lazy val familyObj = Family.fromName($(family)) + lazy val linkObj = if (isDefined(link)) { + Link.fromName($(link)) + } else { + familyObj.defaultLink + } + lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + + override protected def predict(features: Vector): Double = { + val eta = BLAS.dot(features, coefficients) + intercept + familyAndLink.fitted(eta) + } + + @Since("2.0.0") + override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { + copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) + .setParent(parent) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8f78fd122f34..b4f17b8e2898 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -163,8 +163,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String }.first() val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) || - $(solver) == "normal") { + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && + numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index c3882606d7db..f807b5683c39 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -408,6 +408,10 @@ class LogisticRegressionWithLBFGS * defaults to the mllib implementation. If more than two classes * or feature scaling is disabled, always uses mllib implementation. * Uses user provided weights. + * + * In the ml LogisticRegression implementation, the number of corrections + * used in the LBFGS update can not be configured. So `optimizer.setNumCorrections()` + * will have no effect if we fall into that route. */ override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { run(input, initialWeights, userSuppliedWeights = true) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala new file mode 100644 index 000000000000..8bfa9855ce4e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} +import org.apache.spark.mllib.random._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + private val seed: Int = 42 + @transient var datasetGaussianIdentity: DataFrame = _ + @transient var datasetGaussianLog: DataFrame = _ + @transient var datasetGaussianInverse: DataFrame = _ + @transient var datasetBinomial: DataFrame = _ + @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonIdentity: DataFrame = _ + @transient var datasetPoissonSqrt: DataFrame = _ + @transient var datasetGammaInverse: DataFrame = _ + @transient var datasetGammaIdentity: DataFrame = _ + @transient var datasetGammaLog: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + import GeneralizedLinearRegressionSuite._ + + datasetGaussianIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "identity"), 2)) + + datasetGaussianLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "log"), 2)) + + datasetGaussianInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gaussian", link = "inverse"), 2)) + + datasetBinomial = { + val nPoints = 10000 + val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + + val testData = + generateMultinomialLogisticInput(coefficients, xMean, xVariance, + addIntercept = true, nPoints, seed) + + sqlContext.createDataFrame(sc.parallelize(testData, 2)) + } + + datasetPoissonLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family 
= "poisson", link = "log"), 2)) + + datasetPoissonIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "identity"), 2)) + + datasetPoissonSqrt = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "poisson", link = "sqrt"), 2)) + + datasetGammaInverse = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "inverse"), 2)) + + datasetGammaIdentity = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "identity"), 2)) + + datasetGammaLog = sqlContext.createDataFrame( + sc.parallelize(generateGeneralizedLinearRegressionInput( + intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, + family = "gamma", link = "log"), 2)) + } + + test("params") { + ParamsSuite.checkParams(new GeneralizedLinearRegression) + val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("generalized linear regression: default params") { + val glr = new GeneralizedLinearRegression + assert(glr.getLabelCol === "label") + assert(glr.getFeaturesCol === "features") + assert(glr.getPredictionCol === "prediction") + assert(glr.getFitIntercept) + assert(glr.getTol === 1E-6) + assert(glr.getWeightCol === "") + assert(glr.getRegParam === 0.0) + assert(glr.getSolver == "irls") + // TODO: Construct model directly instead of via fitting. + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + assert(model.getFamily === "gaussian") + assert(model.getLink === "identity") + } + + test("generalized linear regression: gaussian family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="gaussian", data=data) + print(as.vector(coef(model))) + } + + [1] 2.2960999 0.8087933 + [1] 2.5002642 2.2000403 0.5999485 + + data <- read.csv("path", header=FALSE) + model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0)) + model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0)) + print(as.vector(coef(model1))) + print(as.vector(coef(model2))) + + [1] 0.23069326 0.07993778 + [1] 0.25001858 0.22002452 0.05998789 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=gaussian(link=inverse), data=data) + print(as.vector(coef(model))) + } + + [1] 2.3010179 0.8198976 + [1] 2.4108902 2.2130248 0.6086152 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2960999, 0.8087933), + Vectors.dense(2.5002642, 2.2000403, 0.5999485), + Vectors.dense(0.0, 0.23069326, 0.07993778), + Vectors.dense(0.25001858, 0.22002452, 0.05998789), + Vectors.dense(0.0, 2.3010179, 0.8198976), + Vectors.dense(2.4108902, 2.2130248, 0.6086152)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog), + ("inverse", datasetGaussianInverse))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gaussian family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gaussian family against glmnet") { + /* + R code: + library(glmnet) + data <- read.csv("path", header=FALSE) + label = data$V1 + features = as.matrix(data.frame(data$V2, data$V3)) + for (intercept in c(FALSE, TRUE)) { + for (lambda in c(0.0, 0.1, 1.0)) { + model <- glmnet(features, label, family="gaussian", intercept=intercept, + lambda=lambda, alpha=0, thresh=1E-14) + print(as.vector(coef(model))) + } + } + + [1] 0.0000000 2.2961005 0.8087932 + [1] 0.0000000 2.2130368 0.8309556 + [1] 0.0000000 1.7176137 0.9610657 + [1] 2.5002642 2.2000403 0.5999485 + [1] 3.1106389 2.0935142 0.5712711 + [1] 6.7597127 1.4581054 0.3994266 + */ + + val expected = Seq( + Vectors.dense(0.0, 2.2961005, 0.8087932), + Vectors.dense(0.0, 2.2130368, 0.8309556), + Vectors.dense(0.0, 1.7176137, 0.9610657), + Vectors.dense(2.5002642, 2.2000403, 
0.5999485), + Vectors.dense(3.1106389, 2.0935142, 0.5712711), + Vectors.dense(6.7597127, 1.4581054, 0.3994266)) + + var idx = 0 + for (fitIntercept <- Seq(false, true); + regParam <- Seq(0.0, 0.1, 1.0)) { + val trainer = new GeneralizedLinearRegression().setFamily("gaussian") + .setFitIntercept(fitIntercept).setRegParam(regParam) + val model = trainer.fit(datasetGaussianIdentity) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " + + s"fitIntercept = $fitIntercept and regParam = $regParam.") + + idx += 1 + } + } + + test("generalized linear regression: binomial family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 + data <- read.csv("path", header=FALSE) + + for (formula in c(f1, f2)) { + model <- glm(formula, family="binomial", data=data) + print(as.vector(coef(model))) + } + + [1] -0.3560284 1.3010002 -0.3570805 -0.7406762 + [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=probit), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2134390 0.7800646 -0.2144267 -0.4438358 + [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850 + + for (formula in c(f1, f2)) { + model <- glm(formula, family=binomial(link=cloglog), data=data) + print(as.vector(coef(model))) + } + + [1] -0.2832198 0.8434144 -0.2524727 -0.5293452 + [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758 + */ + val expected = Seq( + Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762), + Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989), + Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358), + Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850), + Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452), + Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial), + ("cloglog", datasetBinomial))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1), + model.coefficients(2), model.coefficients(3)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Binomial, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"binomial family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: poisson family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="poisson", data=data) + 
print(as.vector(coef(model))) + } + + [1] 0.22999393 0.08047088 + [1] 0.25022353 0.21998599 0.05998621 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2929501 0.8119415 + [1] 2.5012730 2.1999407 0.5999107 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=poisson(link=sqrt), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2958947 0.8090515 + [1] 2.5000480 2.1999972 0.5999968 + */ + val expected = Seq( + Vectors.dense(0.0, 0.22999393, 0.08047088), + Vectors.dense(0.25022353, 0.21998599, 0.05998621), + Vectors.dense(0.0, 2.2929501, 0.8119415), + Vectors.dense(2.5012730, 2.1999407, 0.5999107), + Vectors.dense(0.0, 2.2958947, 0.8090515), + Vectors.dense(2.5000480, 2.1999972, 0.5999968)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity), + ("sqrt", datasetPoissonSqrt))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link) + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Poisson, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"poisson family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } + + test("generalized linear regression: gamma family against glm") { + /* + R code: + f1 <- data$V1 ~ data$V2 + data$V3 - 1 + f2 <- data$V1 ~ data$V2 + data$V3 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family="Gamma", data=data) + print(as.vector(coef(model))) + } + + [1] 2.3392419 0.8058058 + [1] 2.3507700 2.2533574 0.6042991 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=identity), data=data) + print(as.vector(coef(model))) + } + + [1] 2.2908883 0.8147796 + [1] 2.5002406 2.1998346 0.6000059 + + data <- read.csv("path", header=FALSE) + for (formula in c(f1, f2)) { + model <- glm(formula, family=Gamma(link=log), data=data) + print(as.vector(coef(model))) + } + + [1] 0.22958970 0.08091066 + [1] 0.25003210 0.21996957 0.06000215 + */ + val expected = Seq( + Vectors.dense(0.0, 2.3392419, 0.8058058), + Vectors.dense(2.3507700, 2.2533574, 0.6042991), + Vectors.dense(0.0, 2.2908883, 0.8147796), + Vectors.dense(2.5002406, 2.1998346, 0.6000059), + Vectors.dense(0.0, 0.22958970, 0.08091066), + Vectors.dense(0.25003210, 0.21996957, 0.06000215)) + + import GeneralizedLinearRegression._ + + var idx = 0 + for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), + ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + .setFitIntercept(fitIntercept) + val model = 
trainer.fit(dataset) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " + + s"$link link and fitIntercept = $fitIntercept.") + + val familyLink = new FamilyAndLink(Gamma, Link.fromName(link)) + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val eta = BLAS.dot(features, model.coefficients) + model.intercept + val prediction2 = familyLink.fitted(eta) + assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " + + s"gamma family, $link link and fitIntercept = $fitIntercept.") + } + + idx += 1 + } + } + } +} + +object GeneralizedLinearRegressionSuite { + + def generateGeneralizedLinearRegressionInput( + intercept: Double, + coefficients: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + noiseLevel: Double, + family: String, + link: String): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + def rndElement(i: Int) = { + (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + } + val (generator, mean) = family match { + case "gaussian" => (new StandardNormalGenerator, 0.0) + case "poisson" => (new PoissonGenerator(1.0), 1.0) + case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0) + } + generator.setSeed(seed) + + (0 until nPoints).map { _ => + val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray) + val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept + val mu = link match { + case "identity" => eta + case "log" => math.exp(eta) + case "sqrt" => math.pow(eta, 2.0) + case "inverse" => 1.0 / eta + } + val label = mu + noiseLevel * (generator.nextValue() - mean) + // Return LabeledPoints with DenseVector + LabeledPoint(label, features) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala index 77a2773c36f5..dcb1f398b04b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -42,6 +42,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { .collect() /* Verify results using the `R` code: + library(arules) transactions = as(sapply( list("r z h k p", "z y x w v u t s", @@ -52,7 +53,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { FUN=function(x) strsplit(x," ",fixed=TRUE)), "transactions") ars = apriori(transactions, - parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2)) arsDF = as(ars, "data.frame") arsDF$support = arsDF$support * length(transactions) names(arsDF)[names(arsDF) == "support"] = "freq" diff --git a/pom.xml b/pom.xml index 2376e307ced1..2148379896d3 100644 --- a/pom.xml +++ b/pom.xml @@ -89,7 +89,8 @@ common/sketch common/network-common common/network-shuffle - tags + common/unsafe + common/tags core graphx mllib @@ -99,7 +100,6 @@ sql/core sql/hive docker-integration-tests - unsafe assembly external/twitter external/flume diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 3179fb30ab4d..253af15cb5cd 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,11 
+26,12 @@ from pyspark.mllib.common import inherit_doc -__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', - 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel', 'MultilayerPerceptronClassifier', - 'MultilayerPerceptronClassificationModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', + 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel', + 'NaiveBayes', 'NaiveBayesModel', + 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel'] @inherit_doc diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 611b9190491c..1cea477acb47 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -21,7 +21,8 @@ from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -__all__ = ['KMeans', 'KMeansModel', 'BisectingKMeans', 'BisectingKMeansModel'] +__all__ = ['BisectingKMeans', 'BisectingKMeansModel', + 'KMeans', 'KMeansModel'] class KMeansModel(JavaModel, MLWritable, MLReadable): diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 369f3508fda5..fb31c7310c0a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,15 +27,34 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', - 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', - 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinMaxScaler', 'MinMaxScalerModel', - 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', - 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel', - 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', - 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', 'ChiSqSelector', - 'ChiSqSelectorModel'] +__all__ = ['Binarizer', + 'Bucketizer', + 'ChiSqSelector', 'ChiSqSelectorModel', + 'CountVectorizer', 'CountVectorizerModel', + 'DCT', + 'ElementwiseProduct', + 'HashingTF', + 'IDF', 'IDFModel', + 'IndexToString', + 'MaxAbsScaler', 'MaxAbsScalerModel', + 'MinMaxScaler', 'MinMaxScalerModel', + 'NGram', + 'Normalizer', + 'OneHotEncoder', + 'PCA', 'PCAModel', + 'PolynomialExpansion', + 'QuantileDiscretizer', + 'RegexTokenizer', + 'RFormula', 'RFormulaModel', + 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', + 'StringIndexer', 'StringIndexerModel', + 'Tokenizer', + 'VectorAssembler', + 'VectorIndexer', 'VectorIndexerModel', + 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index b4d54ef61b0e..57106f8690a7 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -294,7 +294,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: 0.01) :param regType: The type of regularizer used for training our model. 
- Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) @@ -326,7 +326,7 @@ class LogisticRegressionWithLBFGS(object): """ @classmethod @since('1.2.0') - def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", + def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2", intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2): """ Train a logistic regression model on the given data. @@ -341,10 +341,10 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: None) :param regParam: The regularizer parameter. - (default: 0.01) + (default: 0.0) :param regType: The type of regularizer used for training our model. - Allowed values: + Supported values: - "l1" for using L1 regularization - "l2" for using L2 regularization (default) @@ -356,7 +356,9 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType (default: False) :param corrections: The number of corrections used in the LBFGS update. - (default: 10) + If a known updater is used for binary classification, + it calls the ml implementation and this parameter will + have no effect. (default: 10) :param tolerance: The convergence tolerance of iterations for L-BFGS. (default: 1e-6) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 4dd7083d79c8..3b77a6200054 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -37,10 +37,11 @@ class LabeledPoint(object): """ Class that represents the features and labels of a data point. - :param label: Label for this data point. - :param features: Vector of features for this point (NumPy array, - list, pyspark.mllib.linalg.SparseVector, or scipy.sparse - column matrix) + :param label: + Label for this data point. + :param features: + Vector of features for this point (NumPy array, list, + pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix). Note: 'label' and 'features' are accessible as class attributes. @@ -66,8 +67,10 @@ class LinearModel(object): """ A linear model that has a vector of coefficients and an intercept. - :param weights: Weights computed for every feature. - :param intercept: Intercept computed for this model. + :param weights: + Weights computed for every feature. + :param intercept: + Intercept computed for this model. .. versionadded:: 0.9.0 """ @@ -217,19 +220,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): class LinearRegressionWithSGD(object): """ - Train a linear regression model with no regularization using Stochastic Gradient Descent. - This solves the least squares regression formulation - - f(weights) = 1/n ||A weights-y||^2 - - which is the mean squared error. - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @@ -237,47 +229,52 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient - Descent (SGD). - This solves the least squares regression formulation - - f(weights) = 1/(2n) ||A weights - y||^2, - - which is the mean squared error. 
- Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter - (default: 0.0). - :param regType: The type of regularizer used for - training our model. - - :Allowed values: - - "l1" for using L1 regularization (lasso), - - "l2" for using L2 regularization (ridge), - - None for no regularization - - (default: None) - - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Descent (SGD). This solves the least squares regression + formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + + which is the mean squared error. Here the data matrix has n rows, + and the input RDD holds the set of rows of A, each with its + corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param regParam: + The regularizer parameter. + (default: 0.0) + :param regType: + The type of regularizer used for training our model. + Supported values: + + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization (default) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e., whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -368,56 +365,53 @@ def load(cls, sc, path): class LassoWithSGD(object): """ - Train a regression model with L1-regularization using Stochastic Gradient Descent. - This solves the L1-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. 
versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L1-regularization using - Stochastic Gradient Descent. - This solves the l1-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L1-regularization using Stochastic + Gradient Descent. This solves the l1-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -508,56 +502,53 @@ def load(cls, sc, path): class RidgeRegressionWithSGD(object): """ - Train a regression model with L2-regularization using Stochastic Gradient Descent. - This solves the L2-regularized least squares regression formulation - - f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 - - Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with - its corresponding right hand side label y. - See also the documentation for the precise formulation. - .. 
versionadded:: 0.9.0 """ - @classmethod @since("0.9.0") def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, validateData=True, convergenceTol=0.001): """ - Train a regression model with L2-regularization using - Stochastic Gradient Descent. - This solves the l2-regularized least squares regression - formulation - - f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. - - Here the data matrix has n rows, and the input RDD holds the - set of rows of A, each with its corresponding right hand side - label y. See also the documentation for the precise formulation. - - :param data: The training data, an RDD of - LabeledPoint. - :param iterations: The number of iterations - (default: 100). - :param step: The step parameter used in SGD - (default: 1.0). - :param regParam: The regularizer parameter - (default: 0.01). - :param miniBatchFraction: Fraction of data to be used for each - SGD iteration (default: 1.0). - :param initialWeights: The initial weights (default: None). - :param intercept: Boolean parameter which indicates the - use or not of the augmented representation - for training data (i.e. whether bias - features are activated or not, - default: False). - :param validateData: Boolean parameter which indicates if - the algorithm should validate data - before training. (default: True) - :param convergenceTol: A condition which decides iteration termination. - (default: 0.001) + Train a regression model with L2-regularization using Stochastic + Gradient Descent. This solves the l2-regularized least squares + regression formulation + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2 + + Here the data matrix has n rows, and the input RDD holds the set + of rows of A, each with its corresponding right hand side label y. + See also the documentation for the precise formulation. + + :param data: + The training data, an RDD of LabeledPoint. + :param iterations: + The number of iterations. + (default: 100) + :param step: + The step parameter used in SGD. + (default: 1.0) + :param regParam: + The regularizer parameter. + (default: 0.01) + :param miniBatchFraction: + Fraction of data to be used for each SGD iteration. + (default: 1.0) + :param initialWeights: + The initial weights. + (default: None) + :param intercept: + Boolean parameter which indicates the use or not of the + augmented representation for training data (i.e. whether bias + features are activated or not). + (default: False) + :param validateData: + Boolean parameter which indicates if the algorithm should + validate data before training. + (default: True) + :param convergenceTol: + A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), @@ -572,12 +563,14 @@ class IsotonicRegressionModel(Saveable, Loader): """ Regression model for isotonic regression. - :param boundaries: Array of boundaries for which predictions are - known. Boundaries must be sorted in increasing order. - :param predictions: Array of predictions associated to the - boundaries at the same index. Results of isotonic - regression and therefore monotone. - :param isotonic: indicates whether this is isotonic or antitonic. + :param boundaries: + Array of boundaries for which predictions are known. Boundaries + must be sorted in increasing order. + :param predictions: + Array of predictions associated to the boundaries at the same + index. 
Results of isotonic regression and therefore monotone. + :param isotonic: + Indicates whether this is isotonic or antitonic. >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) @@ -628,7 +621,8 @@ def predict(self, x): values with the same boundary then the same rules as in 2) are used. - :param x: Feature or RDD of Features to be labeled. + :param x: + Feature or RDD of Features to be labeled. """ if isinstance(x, RDD): return x.map(lambda v: self.predict(v)) @@ -657,8 +651,8 @@ def load(cls, sc, path): class IsotonicRegression(object): """ Isotonic regression. - Currently implemented using parallelized pool adjacent violators algorithm. - Only univariate (single feature) algorithm supported. + Currently implemented using parallelized pool adjacent violators + algorithm. Only univariate (single feature) algorithm supported. Sequential PAV implementation based on: @@ -684,8 +678,11 @@ def train(cls, data, isotonic=True): """ Train a isotonic regression model on the given data. - :param data: RDD of (label, feature, weight) tuples. - :param isotonic: Whether this is isotonic or antitonic. + :param data: + RDD of (label, feature, weight) tuples. + :param isotonic: + Whether this is isotonic (which is default) or antitonic. + (default: True) """ boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel", data.map(_convert_to_vector), bool(isotonic)) @@ -721,9 +718,11 @@ def _validate(self, dstream): @since("1.5.0") def predictOn(self, dstream): """ - Make predictions on a dstream. + Use the model to make predictions on batches of data from a + DStream. - :return: Transformed dstream object. + :return: + DStream containing predictions. """ self._validate(dstream) return dstream.map(lambda x: self._model.predict(x)) @@ -731,9 +730,11 @@ def predictOn(self, dstream): @since("1.5.0") def predictOnValues(self, dstream): """ - Make predictions on a keyed dstream. + Use the model to make predictions on the values of a DStream and + carry over its keys. - :return: Transformed dstream object. + :return: + DStream containing the input keys and the predictions as values. """ self._validate(dstream) return dstream.mapValues(lambda x: self._model.predict(x)) @@ -742,14 +743,15 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Train or predict a linear regression model on streaming data. Training uses - Stochastic Gradient Descent to update the model based on each new batch of - incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + Train or predict a linear regression model on streaming data. + Training uses Stochastic Gradient Descent to update the model + based on each new batch of incoming data from a DStream + (see `LinearRegressionWithSGD` for model equation). Each batch of data is assumed to be an RDD of LabeledPoints. The number of data points per batch can vary, but the number - of features must be constant. An initial weight - vector must be provided. + of features must be constant. An initial weight vector must + be provided. :param stepSize: Step size for each iteration of gradient descent. 
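The reworked docstrings above only re-document the existing pyspark.mllib.regression training API; as a quick, hedged illustration of the parameters they describe (a minimal sketch that assumes a running SparkContext named sc and a made-up four-point dataset, nothing introduced by this patch):

from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

# Toy dataset; `sc` is assumed to be an already-created SparkContext.
data = sc.parallelize([
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(3.0, [2.0]),
    LabeledPoint(2.0, [3.0]),
])

# regParam=0.0 and regType=None match the documented defaults, i.e. plain
# (unregularized) least squares f(weights) = 1/(2n) ||A weights - y||^2.
model = LinearRegressionWithSGD.train(data, iterations=100, step=1.0,
                                      regParam=0.0, regType=None, intercept=False)
print(model.weights, model.intercept)

Passing regType="l1" or "l2" together with a non-zero regParam switches the same call to the Lasso and ridge objectives documented above.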
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index b1453c637f79..7f5368d8bdbb 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -233,6 +233,23 @@ def text(self, paths): paths = [paths] return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(2.0) + def csv(self, paths): + """Loads a CSV file and returns the result as a :class:`DataFrame`. + + This function goes through the input once to determine the input schema. To avoid going + through the entire data once, specify the schema explicitly using ``schema``. + + :param paths: string, or list of strings, for input path(s). + + >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + >>> df.dtypes + [('C0', 'string'), ('C1', 'string')] + """ + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + @since(1.5) def orc(self, path): """Loads an ORC file, returning the result as a :class:`DataFrame`. @@ -448,6 +465,11 @@ def json(self, path, mode=None): * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + You can set the following JSON-specific option(s) for writing JSON files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shortened names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode)._jwrite.json(path) @@ -476,11 +498,39 @@ def parquet(self, path, mode=None, partitionBy=None): def text(self, path): """Saves the content of the DataFrame in a text file at the specified path. + :param path: the path in any Hadoop supported file system + The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. + + You can set the following option(s) for writing text files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shortened names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). """ self._jwrite.text(path) + @since(2.0) + def csv(self, path, mode=None): + """Saves the content of the :class:`DataFrame` in CSV format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + You can set the following CSV-specific option(s) for writing CSV files: + * ``compression`` (default ``None``): compression codec to use when saving to file. + This can be one of the known case-insensitive shortened names + (``bzip2``, ``gzip``, ``lz4``, and ``snappy``). + + >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode)._jwrite.csv(path) + @since(1.5) def orc(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. 
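As a hedged end-to-end sketch of the two csv() methods added above (mirroring their doctests; it assumes an existing SQLContext named sqlContext and the python/test_support/sql/ages.csv fixture added in the next hunk):

import os
import tempfile

# Read the test fixture; with no explicit schema every column comes back as string.
df = sqlContext.read.csv('python/test_support/sql/ages.csv')
print(df.dtypes)  # [('C0', 'string'), ('C1', 'string')]

# Round-trip through the new writer; `mode` takes the usual save modes.
out = os.path.join(tempfile.mkdtemp(), 'ages_out')
df.write.csv(out, mode='overwrite')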
diff --git a/python/test_support/sql/ages.csv b/python/test_support/sql/ages.csv new file mode 100644 index 000000000000..18991feda788 --- /dev/null +++ b/python/test_support/sql/ages.csv @@ -0,0 +1,4 @@ +Joe,20 +Tom,30 +Hyukjin,25 + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 27ae62f1212f..0ad0f4976c77 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; -final class UnsafeExternalRowSorter { +public final class UnsafeExternalRowSorter { /** * If positive, forces records to be spilled to disk at the given frequency (measured in numbers @@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) { testSpillFrequency = frequency; } - @VisibleForTesting - void insertRow(UnsafeRow row) throws IOException { + public void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( row.getBaseObject(), @@ -110,8 +109,7 @@ private void cleanupResources() { sorter.cleanupResources(); } - @VisibleForTesting - Iterator sort() throws IOException { + public Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -160,7 +158,6 @@ public UnsafeRow next() { } } - public Iterator sort(Iterator inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 23e4709bbd88..876aa0eae0e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -559,7 +561,13 @@ class Analyzer( } resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + case n: NewInstance + // If this is an inner class of another class, register the outer object in `OuterScopes`. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. 
+ if n.outerPointer.isEmpty && + n.cls.isMemberClass && + !Modifier.isStatic(n.cls.getModifiers) => val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { throw new AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 4be065b30a21..3ee19cc4ad71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.DecimalFormat +import java.text.{DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow @@ -938,8 +938,10 @@ case class FormatNumber(x: Expression, d: Expression) @transient private val pattern: StringBuffer = new StringBuffer() + // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') + // as a decimal separator. @transient - private val numberFormat: DecimalFormat = new DecimalFormat("") + private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { val dValue = dObject.asInstanceOf[Int] @@ -962,10 +964,9 @@ case class FormatNumber(x: Expression, d: Expression) pattern.append("0") } } - val dFormat = new DecimalFormat(pattern.toString) lastDValue = dValue - numberFormat.applyPattern(dFormat.toPattern) + numberFormat.applyLocalizedPattern(pattern.toString) } x.dataType match { @@ -992,6 +993,11 @@ case class FormatNumber(x: Expression, d: Expression) val sb = classOf[StringBuffer].getName val df = classOf[DecimalFormat].getName + val dfs = classOf[DecimalFormatSymbols].getName + val l = classOf[Locale].getName + // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') + // as a decimal separator. 
+ val usLocale = "US" val lastDValue = ctx.freshName("lastDValue") val pattern = ctx.freshName("pattern") val numberFormat = ctx.freshName("numberFormat") @@ -999,7 +1005,8 @@ case class FormatNumber(x: Expression, d: Expression) val dFormat = ctx.freshName("dFormat") ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""") + ctx.addMutableState(df, numberFormat, + s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { @@ -1013,9 +1020,8 @@ case class FormatNumber(x: Expression, d: Expression) $pattern.append("0"); } } - $df $dFormat = new $df($pattern.toString()); $lastDValue = $d; - $numberFormat.applyPattern($dFormat.toPattern()); + $numberFormat.applyLocalizedPattern($pattern.toString()); } ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8095083f336e..31e775d60f95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + override protected def validConstraints: Set[Expression] = child.constraints override def statistics: Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5d2a65b716b2..e81a0f948746 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias - */ - private def getAliasedConstraints: Set[Expression] = { - projectList.flatMap { - case a @ Alias(e, _) => - child.constraints.map(_ transform { - case expr: Expression if expr.semanticEquals(e) => - a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet - } - - override def validConstraints: Set[Expression] = { - child.constraints.union(getAliasedConstraints) - } + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(projectList)) } /** @@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = 
child.maxRows - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = child.constraints.union(splitConjunctivePredicates(condition).toSet) - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - } // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. @@ -442,6 +423,9 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(aggregateExpressions)) + override def statistics: Statistics = { if (groupingExpressions.isEmpty) { Statistics(sizeInBytes = 1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ed..6a59e9728a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -340,6 +340,9 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + /** Maximum number of decimal digits an Int can represent */ + val MAX_INT_DIGITS = 9 + /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 2e03ddae760b..9c1319c1c5e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside an int + */ + def is32BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_INT_DIGITS + case _ => false + } + } + /** * Returns if dt is a DecimalType that fits inside a long */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 373b1ffa83d2..b68432b1a128 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aggregate") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "c1") > 10, + 
IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), + resolveColumn(aliasedRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index e7f0ec2e7789..57dbd7c2ff56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -257,8 +257,7 @@ private void initializeInternal() throws IOException { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && - primitiveType.getDecimalMetadata().getPrecision() > - CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) { throw new IOException("Decimal with high precision is not supported."); } if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { @@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); int precision = type.getDecimalMetadata().getPrecision(); int scale = type.getDecimalMetadata().getScale(); - Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(), "Unsupported precision."); for (int n = 0; n < num; ++n) { @@ -480,11 +479,6 @@ private final class ColumnReader { */ private boolean useDictionary; - /** - * If useDictionary is true, the staging vector used to decode the ids. - */ - private ColumnVector dictionaryIds; - /** * Maximum definition level for this column. */ @@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException { } int num = Math.min(total, leftInPage); if (useDictionary) { - // Data is dictionary encoded. We will vector decode the ids and then resolve the values. - if (dictionaryIds == null) { - dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(total); - } // Read and decode dictionary ids. + ColumnVector dictionaryIds = column.reserveDictionaryIds(total);; defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - decodeDictionaryIds(rowId, num, column); + decodeDictionaryIds(rowId, num, column, dictionaryIds); } else { + column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. 
*/ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + private void decodeDictionaryIds(int rowId, int num, ColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: - if (column.dataType() == DataTypes.IntegerType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ByteType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ShortType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case INT64: - if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case FLOAT: - for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); - } - break; - case DOUBLE: - for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); - } + case BINARY: + column.setDictionary(dictionary); break; case FIXED_LEN_BYTE_ARRAY: - if (DecimalType.is64BitDecimalType(column.dataType())) { + // DecimalType written in the legacy mode + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); @@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { } break; - case BINARY: - // TODO: this is incredibly inefficient as it blows up the dictionary right here. We - // need to do this better. We should probably add the dictionary data to the ColumnVector - // and reuse it across batches. This should mean adding a ByteArray would just update - // the length and offset. - for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); - } - break; - default: throw new NotImplementedException("Unsupported type: " + descriptor.getType()); } @@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IO private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. 
// TODO: implement remaining type conversions - if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || + DecimalType.is32BitDecimalType(column.dataType())) { defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - defColumn.readIntsAsLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ShortType) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num, VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (DecimalType.is64BitDecimalType(column.dataType())) { + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putInt(rowId + i, + (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 8613fcae0b80..62157389013b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -25,7 +25,6 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c, } } - public void readIntsAsLongs(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { - int left = total; - while (left > 0) { - if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - switch (mode) { - case RLE: - if (currentValue == level) { - for (int i = 0; i < n; i++) { - c.putLong(rowId + i, data.readInteger()); - } - } else { - c.putNulls(rowId, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == level) { - c.putLong(rowId + i, data.readInteger()); - } else { - c.putNull(rowId + i); - } - } - break; - } - rowId += n; - left -= n; - currentCount -= n; - } - } - public void readBytes(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0514252a8e53..bb0247c2fbed 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,6 +19,10 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -27,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. @@ -157,7 +159,7 @@ public Object[] array() { } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { - list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + list[i] = getUTF8String(i).toString(); } } } else if (dt instanceof CalendarIntervalType) { @@ -204,28 +206,17 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return data.getDecimal(offset + ordinal, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - Array child = data.getByteArray(offset + ordinal); - return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + return data.getUTF8String(offset + ordinal); } @Override public byte[] getBinary(int ordinal) { - ColumnVector.Array array = data.getByteArray(offset + ordinal); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return data.getBinary(offset + ordinal); } @Override @@ -534,12 +525,57 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - public final Array getByteArray(int rowId) { + private Array getByteArray(int rowId) { Array array = getArray(rowId); array.data.loadBytes(array); return array; } + /** + * Returns the decimal for rowId. + */ + public final Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.apply(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(rowId), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + /** + * Returns the UTF8String for rowId. + */ + public final UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return UTF8String.fromBytes(v.getBytes()); + } + } + + /** + * Returns the byte array for rowId. 
+ */ + public final byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return v.getBytes(); + } + } + /** * Append APIs. These APIs all behave similarly and will append data to the current vector. It * is not valid to mix the put and append APIs. The append APIs are slower and should only be @@ -816,6 +852,39 @@ public final int appendStruct(boolean isNull) { */ protected final ColumnarBatch.Row resultStruct; + /** + * The Dictionary for this column. + * + * If it is not null, it will be used to decode the values in the getXXX() methods. + */ + protected Dictionary dictionary; + + /** + * Reusable column for the ids of the dictionary. + */ + protected ColumnVector dictionaryIds; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve an integer column for the ids of the dictionary. + */ + public ColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = allocate(capacity, DataTypes.IntegerType, + this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2aeef7f2f90f..681ace338713 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -22,24 +22,20 @@ import java.util.Iterator; import java.util.List; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.commons.lang.NotImplementedException; - /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly * for debugging or other non-performance critical paths. * These utilities are mostly used to convert ColumnVectors into other formats. */ public class ColumnVectorUtils { - public static String toString(ColumnVector.Array a) { - return new String(a.byteArray, a.byteArrayOffset, a.length); - } - /** * Returns the array data as the java primitive array. * For example, an array of IntegerType will return an int[]. 
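For context on the getDecimal() and reserveDictionaryIds() changes above, a plain-Python sketch (illustration only, not Spark code) of the precision thresholds that getDecimal() branches on, using the MAX_INT_DIGITS and MAX_LONG_DIGITS constants from the Decimal object in this patch:

# Decimals with precision <= 9 fit in an int, <= 18 in a long; anything wider
# falls back to the unscaled big-integer bytes.
MAX_INT_DIGITS = 9
MAX_LONG_DIGITS = 18

def decimal_storage(precision):
    """Mirror the three branches of ColumnVector.getDecimal()."""
    if precision <= MAX_INT_DIGITS:
        return 'int'      # read back with getInt()
    if precision <= MAX_LONG_DIGITS:
        return 'long'     # read back with getLong()
    return 'binary'       # unscaled bytes, read back with getBinary()

assert decimal_storage(7) == 'int'
assert decimal_storage(15) == 'long'
assert decimal_storage(25) == 'binary'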
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 070d897a7158..8a0d7f8b1237 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; @@ -31,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class is the in memory representation of rows as they are streamed through operators. It * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that @@ -193,29 +191,17 @@ public final boolean anyNull() { @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = columns[ordinal].getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + return columns[ordinal].getUTF8String(rowId); } @Override public final byte[] getBinary(int ordinal) { - ColumnVector.Array array = columns[ordinal].getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return columns[ordinal].getBinary(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e38ed051219b..b06b7f2457b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,25 +18,11 @@ import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - - import org.apache.commons.lang.NotImplementedException; -import 
org.apache.commons.lang.NotImplementedException; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. @@ -171,7 +157,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return Platform.getByte(null, data + rowId); + if (dictionary == null) { + return Platform.getByte(null, data + rowId); + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -199,7 +189,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return Platform.getShort(null, data + 2 * rowId); + if (dictionary == null) { + return Platform.getShort(null, data + 2 * rowId); + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -233,7 +227,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return Platform.getInt(null, data + 4 * rowId); + if (dictionary == null) { + return Platform.getInt(null, data + 4 * rowId); + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -267,7 +265,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return Platform.getLong(null, data + 8 * rowId); + if (dictionary == null) { + return Platform.getLong(null, data + 8 * rowId); + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -301,7 +303,11 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public final float getFloat(int rowId) { - return Platform.getFloat(null, data + rowId * 4); + if (dictionary == null) { + return Platform.getFloat(null, data + rowId * 4); + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } } @@ -336,7 +342,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return Platform.getDouble(null, data + rowId * 8); + if (dictionary == null) { + return Platform.getDouble(null, data + rowId * 8); + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -394,7 +404,7 @@ private final void reserveInternal(int newCapacity) { } else if (type instanceof ShortType) { this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || - type instanceof DateType) { + type instanceof DateType || DecimalType.is32BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3502d31bd1df..305e84a86bdc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,12 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.util.Arrays; + import org.apache.spark.memory.MemoryMode; -import 
org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.util.Arrays; - /** * A column backed by an in memory JVM array. This stores the NULLs as a byte per value * and a java array for the values. @@ -68,7 +67,6 @@ public final void close() { doubleData = null; } - // // APIs dealing with nulls // @@ -154,7 +152,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return byteData[rowId]; + if (dictionary == null) { + return byteData[rowId]; + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -180,7 +182,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return shortData[rowId]; + if (dictionary == null) { + return shortData[rowId]; + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } @@ -217,7 +223,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return intData[rowId]; + if (dictionary == null) { + return intData[rowId]; + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -253,7 +263,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return longData[rowId]; + if (dictionary == null) { + return longData[rowId]; + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -280,7 +294,13 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { } @Override - public final float getFloat(int rowId) { return floatData[rowId]; } + public final float getFloat(int rowId) { + if (dictionary == null) { + return floatData[rowId]; + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } // // APIs dealing with doubles @@ -309,7 +329,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return doubleData[rowId]; + if (dictionary == null) { + return doubleData[rowId]; + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -377,7 +401,8 @@ private final void reserveInternal(int newCapacity) { short[] newData = new short[newCapacity]; if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); shortData = newData; - } else if (type instanceof IntegerType || type instanceof DateType) { + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d6bdd3d82556..093504c765ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -453,6 +453,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * format("json").save(path) * }}} * + * You can set the following JSON-specific option(s) for writing JSON files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive short names (`bzip2`, `gzip`, `lz4`, and `snappy`).
  • + * * @since 1.4.0 */ def json(path: String): Unit = format("json").save(path) @@ -492,10 +496,29 @@ final class DataFrameWriter private[sql](df: DataFrame) { * df.write().text("/path/to/output") * }}} * + * You can set the following option(s) for writing text files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive short names (`bzip2`, `gzip`, `lz4`, and `snappy`).
  • + * * @since 1.6.0 */ def text(path: String): Unit = format("text").save(path) + /** + * Saves the content of the [[DataFrame]] in CSV format at the specified path. + * This is equivalent to: + * {{{ + * format("csv").save(path) + * }}} + * + * You can set the following CSV-specific option(s) for writing CSV files: + *
  • `compression` (default `null`): compression codec to use when saving to file. This can be + * one of the known case-insensitive short names (`bzip2`, `gzip`, `lz4`, and `snappy`).
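For reference, a hedged usage sketch of the `compression` option documented above for the json, text, and csv writers. The SQLContext, DataFrame, and output paths are placeholders; any of the documented short names (bzip2, gzip, lz4, snappy) could be used in place of gzip.

```scala
import org.apache.spark.sql.SQLContext

// Assumes an existing SQLContext (e.g. from spark-shell); paths are placeholders.
def writeCompressed(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._
  val df = Seq(("a", 1), ("b", 2)).toDF("k", "v")

  df.write.option("compression", "gzip").json("/tmp/out-json")
  df.select("k").write.option("compression", "gzip").text("/tmp/out-text") // text() expects a single string column
  df.write.option("compression", "gzip").csv("/tmp/out-csv")               // csv(path) is the writer added in this patch
}
```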
  • + * + * @since 2.0.0 + */ + def csv(path: String): Unit = format("csv").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 75cb6d1137c3..2ea889ea72c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -37,7 +39,7 @@ case class Sort( global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) - extends UnaryNode { + extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -50,34 +52,36 @@ case class Sort( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - protected override def doExecute(): RDD[InternalRow] = { - val schema = child.schema - val childOutput = child.output + def createSorter(): UnsafeExternalRowSorter = { + val ordering = newOrdering(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter + } + + protected override def doExecute(): RDD[InternalRow] = { val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") child.execute().mapPartitionsInternal { iter => - val ordering = newOrdering(sortOrder, childOutput) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - // The generator for prefix - val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) 
- if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } + val sorter = createSorter() val metrics = TaskContext.get().taskMetrics() // Remember spill data size of this task before execute this operator so that we can @@ -93,4 +97,74 @@ case class Sort( sortedIterator } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + // Name of sorter variable used in codegen. + private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = ctx.freshName("needToSort") + ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + + + // Initialize the class member variables. This includes the instance of the Sorter and + // the iterator to return sorted rows. + val thisPlan = ctx.addReferenceObj("plan", this) + sorterVariable = ctx.freshName("sorter") + ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, + s"$sorterVariable = $thisPlan.createSorter();") + val metrics = ctx.freshName("metrics") + ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") + val sortedIterator = ctx.freshName("sortedIter") + ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + + val addToSorter = ctx.freshName("addToSorter") + ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + val outputRow = ctx.freshName("outputRow") + val dataSize = metricTerm(ctx, "dataSize") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + s""" + | if ($needToSort) { + | $addToSorter(); + | Long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $sortedIterator = $sorterVariable.sort(); + | $dataSize.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val colExprs = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs) + + s""" + | // Convert the input attributes to an UnsafeRow and add it to the sorter + | ${code.code} + | $sorterVariable.insertRow(${code.value}); + """.stripMargin.trim + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd8c96d5fa1d..0255103b63d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -71,9 +71,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil - // no predicate can be evaluated by matching hash keys - 
case logical.Join(left, right, LeftSemi, condition) => - joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index afaddcf35775..cb68ca6ada36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) ${code.trim} } } - """ + """.trim // try to compile, helpful for debug val cleanedSource = CodeFormatter.stripExtraNewLines(source) @@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) // There is an UnsafeRow already s""" |append($row.copy()); - """.stripMargin + """.stripMargin.trim } else { assert(input != null) if (input.nonEmpty) { @@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) s""" |${code.code.trim} |append(${code.value}.copy()); - """.stripMargin + """.stripMargin.trim } else { // There is no columns s""" |append(unsafeRow); - """.stripMargin + """.stripMargin.trim } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index ace8cd7ad864..7f1ed28046b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ - private[csv] object CSVInferSchema { /** @@ -48,7 +47,11 @@ private[csv] object CSVInferSchema { tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - StructField(thisHeader, rootType, nullable = true) + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } StructType(structFields) @@ -65,12 +68,8 @@ private[csv] object CSVInferSchema { } def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { - first.zipAll(second, NullType, NullType).map { case ((a, b)) => - val tpe = findTightestCommonType(a, b).getOrElse(StringType) - tpe match { - case _: NullType => StringType - case other => other - } + first.zipAll(second, NullType, NullType).map { case (a, b) => + findTightestCommonType(a, b).getOrElse(NullType) } } @@ -140,6 +139,8 @@ private[csv] object CSVInferSchema { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) + case (StringType, t2) => Some(StringType) + case (t1, StringType) => Some(StringType) // Promote numeric types to the highest of the two and all numeric types to unlimited decimal case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => @@ -150,7 +151,6 @@ private[csv] object CSVInferSchema { } } - private[csv] object CSVTypeCast { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 
ee6373d03e1f..9e336422d1f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -44,6 +44,12 @@ private[sql] object JDBCRelation { * exactly once. The parameters minValue and maxValue are advisory in that * incorrect values may cause the partitioning to be poor, but no data * will fail to be represented. + * + * Null value predicate is added to the first partition where clause to include + * the rows with null value for the partitions column. + * + * @param partitioning partition information to generate the where clause for each partition + * @return an array of partitions with where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) @@ -66,7 +72,7 @@ private[sql] object JDBCRelation { if (upperBound == null) { lowerBound } else if (lowerBound == null) { - upperBound + s"$upperBound or $column is null" } else { s"$lowerBound AND $upperBound" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 31a95ed46121..e59dbd6b3d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -48,10 +48,7 @@ private[sql] class JSONOptions( parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 42d89f4bf81d..8a128b4b6176 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter( } protected def decimalFromBinary(value: Binary): Decimal = { - if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { + if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. 
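Stepping back to the JDBCRelation change a few hunks above: the first generated partition now appends `or <column> is null` to its WHERE clause so rows whose partition column is NULL are no longer dropped. Below is a simplified, hypothetical sketch of that clause construction (column name and bounds are made up); only the clause-building branch from the patch is reproduced.

```scala
// Simplified sketch of per-partition WHERE clauses for
// partitionColumn = "Dept", lowerBound = 1, upperBound = 4, numPartitions = 3.
def whereClauses(column: String, lower: Long, upper: Long, numPartitions: Int): Seq[String] = {
  val stride = (upper - lower) / numPartitions
  (0 until numPartitions).map { i =>
    val lBound = if (i == 0) null else s"$column >= ${lower + i * stride}"
    val uBound = if (i == numPartitions - 1) null else s"$column < ${lower + (i + 1) * stride}"
    if (uBound == null) lBound
    else if (lBound == null) s"$uBound or $column is null" // first partition also keeps NULL rows
    else s"$lBound AND $uBound"
  }
}
// whereClauses("Dept", 1, 4, 3) ==
//   Seq("Dept < 2 or Dept is null", "Dept >= 2 AND Dept < 3", "Dept >= 3")
```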
val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index ab4250d0adba..6f6340f541ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) case UINT_8 => typeNotSupported() case UINT_16 => typeNotSupported() case UINT_32 => typeNotSupported() @@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() @@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter { // Returns the minimum number of bytes needed to store a decimal with a given `precision`. 
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ - - val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ - // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { Math.round( // convert double to long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 3508220c9541..0252c79d8e14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -33,7 +33,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi writeLegacyParquetFormat match { // Standard mode, 1 <= precision <= 9, writes as INT32 - case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer // Standard mode, 10 <= precision <= 18, writes as INT64 - case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY - case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY case _ => binaryWriterUsingUnscaledBytes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 60155b32349a..8f3f6335e428 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -115,10 +115,7 @@ private[sql] class TextRelation( /** Write path. 
*/ override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = job.getConfiguration - val compressionCodec = { - val name = parameters.get("compression").orElse(parameters.get("codec")) - name.map(CompressionCodecs.getCodecClassName) - } + val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala deleted file mode 100644 index df6dac88187c..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. 
- */ -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output: Seq[Attribute] = left.output - - /** The Streamed Relation */ - override def left: SparkPlan = streamed - - /** The Broadcast relation */ - override def right: SparkPlan = broadcast - - override def requiredChildDistribution: Seq[Distribution] = { - UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil - } - - @transient private lazy val boundCondition = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - val relation = broadcastedRelation.value - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < relation.length && !matched) { - if (boundCondition(joinedRow(streamedRow, relation(i)))) { - matched = true - } - i += 1 - } - if (matched) { - numOutputRows += 1 - } - matched - }) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cd543d419528..45175d36d5c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { /** * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. 
*/ -trait BaseLimit extends UnaryNode { +trait BaseLimit extends UnaryNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } } /** diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv new file mode 100644 index 000000000000..02d29cabf95f --- /dev/null +++ b/sql/core/src/test/resources/simple_sparse.csv @@ -0,0 +1,5 @@ +A,B,C,D +1,,, +,1,, +,,1, +,,,1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 14fc37b64aa3..33df6375e3aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -621,12 +621,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds.filter(_ => true), Some(1), Some(2), Some(3)) } + + test("SPARK-13540 Dataset of nested class defined in Scala object") { + checkAnswer( + Seq(OuterObject.InnerClass("foo")).toDS(), + OuterObject.InnerClass("foo")) + } } class OuterClass extends Serializable { case class InnerClass(a: String) } +object OuterObject { + case class InnerClass(a: String) +} + case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 3dab848e7b03..5b98c11ef2a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -47,7 +47,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j @@ -67,7 +66,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), - ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE 
key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]), @@ -465,7 +464,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", - classOf[LeftSemiJoinBNL]), + classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 6d6cc0186a96..2d3e34d0e129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("range/limit/sum") { + val N = 500 << 20 + runBenchmark("range/limit/sum", N) { + sqlContext.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + ignore("stat functions") { val N = 100 << 20 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9350205d791d..de371d85d9fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } + + test("Sort should be included in WholeStageCodegen") { + val df = sqlContext.range(3, 0, -1).sort(col("id")) + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index a1796f132600..412f1b89beee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) } + + test("Merging Nulltypes should yeild Nulltype.") { + val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) + assert(mergedNullTypes.deep == Array(NullType).deep) + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 7671bc106610..3ecbb14f2ea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val emptyFile = "empty.csv" private val commentsFile = "comments.csv" private val disableCommentsFile = "disable_comments.csv" + private val simpleSparseFile = "simple_sparse.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema.fieldNames.size === 1) } - test("DDL test with empty file") { sqlContext.sql(s""" |CREATE TEMPORARY TABLE carsTable @@ -268,9 +268,8 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)) cars.coalesce(1).write - .format("csv") .option("header", "true") - .save(csvDir) + .csv(csvDir) val carsCopy = sqlContext.read .format("csv") @@ -396,4 +395,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(carsCopy, withHeader = true) } } + + test("Schema inference correctly identifies the datatype when data is sparse.") { + val df = sqlContext.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(simpleSparseFile)) + + assert( + df.schema.fields.map(field => field.dataType).deep == + Array(IntegerType, IntegerType, IntegerType, IntegerType).deep) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index cef6b79a094d..281a2cffa894 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(batch.column(0).getByte(i) == 1) assert(batch.column(1).getInt(i) == 2) assert(batch.column(2).getLong(i) == 3) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + assert(batch.column(3).getUTF8String(i).toString == "abc") i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 355f916a9755..bc341db5571b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -95,15 +95,6 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using LeftSemiJoinBNL") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - test(s"$testName using BroadcastNestedLoopJoin build left") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") 
{ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index c49f2439fce4..5b4f6f1d2461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("Sort metrics") { + // Assume the execution plan is + // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) + val df = sqlContext.range(10).sort('id) + testSparkPlanMetrics(df, 2, Map.empty) + } + test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 8efdf8adb042..97638a66ab47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -370,7 +370,7 @@ object ColumnarBatchBenchmark { } i = 0 while (i < count) { - sum += column.getByteArray(i).length + sum += column.getUTF8String(i).numBytes() i += 1 } column.reset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 445f311107e3..b3c3e66fbcbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. 
val it = batch.rowIterator() @@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index f8a9a95c873a..30a5e2ea4acd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -171,6 +171,27 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( + "create table test.emp(name TEXT(32) NOT NULL," + + " theid INTEGER, \"Dept\" INTEGER)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('fred', 1, 10)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('mary', 2, null)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('joe ''foo'' \"bar\"', 3, 30)").executeUpdate() + conn.prepareStatement( + "insert into test.emp values ('kathy', null, null)").executeUpdate() + conn.commit() + + sql( + s""" + |CREATE TEMPORARY TABLE nullparts + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.EMP', user 'testUser', password 'testPass', + |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4') + """.stripMargin.replaceAll("\n", " ")) + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. 
} @@ -338,6 +359,23 @@ class JDBCSuite extends SparkFunSuite .collect().length === 3) } + test("Partioning on column that might have null values.") { + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + .collect().length === 4) + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + .collect().length === 4) + // partitioning on a nullable quoted column + assert( + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + .collect().length === 4) + } + + test("SELECT * on partitioned table with a nullable partioncolumn") { + assert(sql("SELECT * FROM nullparts").collect().size == 4) + } + test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() assert(rows.length === 1) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4c9432dbd6ab..aef78fdfd4c5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,7 +18,9 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.lang.reflect.UndeclaredThrowableException import java.nio.charset.StandardCharsets.UTF_8 +import java.security.PrivilegedExceptionAction import java.util.regex.Matcher import java.util.regex.Pattern @@ -194,7 +196,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { */ def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = { try { - obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName) + obtainTokenForHiveMetastoreInner(conf) } catch { case e: ClassNotFoundException => logInfo(s"Hive class not found $e") @@ -209,8 +211,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { * @param username the username of the principal requesting the delegating token. 
* @return a delegation token */ - private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration, - username: String): Option[Token[DelegationTokenIdentifier]] = { + private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration): + Option[Token[DelegationTokenIdentifier]] = { val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) // the hive configuration class is a subclass of Hadoop Configuration, so can be cast down @@ -225,11 +227,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // Check for local metastore if (metastoreUri.nonEmpty) { - require(username.nonEmpty, "Username undefined") val principalKey = "hive.metastore.kerberos.principal" val principal = hiveConf.getTrimmed(principalKey, "") require(principal.nonEmpty, "Hive principal $principalKey undefined") - logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri") + val currentUser = UserGroupInformation.getCurrentUser() + logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " + + s"$principal at $metastoreUri") val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") val closeCurrent = hiveClass.getMethod("closeCurrent") try { @@ -238,12 +241,14 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { classOf[String], classOf[String]) val getHive = hiveClass.getMethod("get", hiveConfClass) - // invoke - val hive = getHive.invoke(null, hiveConf) - val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String] - val hive2Token = new Token[DelegationTokenIdentifier]() - hive2Token.decodeFromUrlString(tokenStr) - Some(hive2Token) + doAsRealUser { + val hive = getHive.invoke(null, hiveConf) + val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal) + .asInstanceOf[String] + val hive2Token = new Token[DelegationTokenIdentifier]() + hive2Token.decodeFromUrlString(tokenStr) + Some(hive2Token) + } } finally { Utils.tryLogNonFatalError { closeCurrent.invoke(null) @@ -303,6 +308,25 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } } + /** + * Run some code as the real logged in user (which may differ from the current user, for + * example, when using proxying). + */ + private def doAsRealUser[T](fn: => T): T = { + val currentUser = UserGroupInformation.getCurrentUser() + val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) + + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. 
+ try { + realUser.doAs(new PrivilegedExceptionAction[T]() { + override def run(): T = fn + }) + } catch { + case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e) + } + } + } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index d3acaf229cc8..9202bd892f01 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -255,7 +255,7 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging hadoopConf.set("hive.metastore.uris", "http://localhost:0") val util = new YarnSparkHadoopUtil assertNestedHiveException(intercept[InvocationTargetException] { - util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") + util.obtainTokenForHiveMetastoreInner(hadoopConf) }) assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf)
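To summarize the YARN change above: the Hive metastore delegation token is now obtained as the real logged-in user (which matters when proxying), and the UndeclaredThrowableException produced by the Scala-generated PrivilegedExceptionAction is unwrapped into its cause. A minimal standalone sketch of that pattern, using only the UserGroupInformation calls that appear in the patch; `fetchToken` is a placeholder for the reflective Hive call.

```scala
import java.lang.reflect.UndeclaredThrowableException
import java.security.PrivilegedExceptionAction

import org.apache.hadoop.security.UserGroupInformation

// Run a block as the real logged-in user, unwrapping the checked-exception wrapper
// that doAs adds around exceptions thrown from the anonymous action.
def doAsRealUser[T](fn: => T): T = {
  val currentUser = UserGroupInformation.getCurrentUser()
  val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser)
  try {
    realUser.doAs(new PrivilegedExceptionAction[T] {
      override def run(): T = fn
    })
  } catch {
    case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e)
  }
}

// Usage (placeholder call):
// val token = doAsRealUser { fetchToken(principal, metastoreUri) }
```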