diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd
index f121b62a53d2..f301606933a9 100644
--- a/bin/spark-submit.cmd
+++ b/bin/spark-submit.cmd
@@ -20,4 +20,4 @@ rem
rem This is the entry point for running Spark submit. To avoid polluting the
rem environment, it just launches a new cmd to do the real work.
-cmd /V /E /C spark-submit2.cmd %*
+cmd /V /E /C "%~dp0spark-submit2.cmd" %*
diff --git a/tags/README.md b/common/tags/README.md
similarity index 100%
rename from tags/README.md
rename to common/tags/README.md
diff --git a/tags/pom.xml b/common/tags/pom.xml
similarity index 97%
rename from tags/pom.xml
rename to common/tags/pom.xml
index 3e8e6f618287..8e702b4fefe8 100644
--- a/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -23,7 +23,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
    <version>2.0.0-SNAPSHOT</version>
-    <relativePath>../pom.xml</relativePath>
+    <relativePath>../../pom.xml</relativePath>
  </parent>
  <groupId>org.apache.spark</groupId>
diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java
similarity index 100%
rename from tags/src/main/java/org/apache/spark/tags/DockerTest.java
rename to common/tags/src/main/java/org/apache/spark/tags/DockerTest.java
diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java
similarity index 100%
rename from tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java
rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java
diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java
similarity index 100%
rename from tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java
rename to common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java
diff --git a/unsafe/pom.xml b/common/unsafe/pom.xml
similarity index 98%
rename from unsafe/pom.xml
rename to common/unsafe/pom.xml
index 75fea556eeae..5250014739da 100644
--- a/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -23,7 +23,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
    <version>2.0.0-SNAPSHOT</version>
-    <relativePath>../pom.xml</relativePath>
+    <relativePath>../../pom.xml</relativePath>
  </parent>
  <groupId>org.apache.spark</groupId>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
rename to common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
rename to common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
rename to common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
rename to common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
rename to common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
similarity index 100%
rename from unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
rename to common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
index 6195916195e3..167c8020850d 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
@@ -149,7 +149,8 @@ $(document).ready(function() {
{name: 'seventh'},
{name: 'eighth'},
],
- "autoWidth": false
+ "autoWidth": false,
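+      // Sort the table by its first column in descending order by default
+      // (DataTables "order" option).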
+ "order": [[ 0, "desc" ]]
};
var rowGroupConf = {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index a1fa266e183e..0e8b735b923b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -244,7 +244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private[spark] def eventLogDir: Option[URI] = _eventLogDir
private[spark] def eventLogCodec: Option[String] = _eventLogCodec
- def isLocal: Boolean = (master == "local" || master.startsWith("local["))
+ def isLocal: Boolean = Utils.isLocalMaster(_conf)
/**
* @return true if context is stopped or in the midst of stopping.
@@ -526,10 +526,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// Optionally scale number of executors dynamically based on workload. Exposed for testing.
val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf)
- if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
- logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.")
- }
-
_executorAllocationManager =
if (dynamicAllocationEnabled) {
Some(new ExecutorAllocationManager(this, listenerBus, _conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 915ef81b4eae..175756b80b6b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -255,6 +255,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
"either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.")
}
}
+
+ if (proxyUser != null && principal != null) {
+ SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.")
+ }
}
private def validateKillArguments(): Unit = {
@@ -517,6 +521,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
|
| --proxy-user NAME User to impersonate when submitting the application.
+ | This argument does not work with --principal / --keytab.
|
| --help, -h Show this help message and exit
| --verbose, -v Print additional debug output
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index a602fcac68a6..a959f200d4cc 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -114,6 +114,19 @@ private[spark] class Executor(
private val heartbeatReceiverRef =
RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
+ /**
+ * When an executor is unable to send heartbeats to the driver more than `HEARTBEAT_MAX_FAILURES`
+   * times, it should kill itself. The default value is 60. This means the executor will keep
+   * retrying heartbeats for about 10 minutes, since the heartbeat interval is 10s.
+ */
+ private val HEARTBEAT_MAX_FAILURES = conf.getInt("spark.executor.heartbeat.maxFailures", 60)
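+  // A hypothetical submit-time override that tolerates ~20 minutes of failed heartbeats
+  // (120 retries * 10s interval): --conf spark.executor.heartbeat.maxFailures=120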
+
+ /**
+   * Counts heartbeat failures. It should only be accessed in the heartbeat thread. Each
+ * successful heartbeat will reset it to 0.
+ */
+ private var heartbeatFailures = 0
+
startDriverHeartbeater()
def launchTask(
@@ -461,8 +474,16 @@ private[spark] class Executor(
logInfo("Told to re-register on heartbeat")
env.blockManager.reregister()
}
+ heartbeatFailures = 0
} catch {
- case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
+ case NonFatal(e) =>
+ logWarning("Issue communicating with driver in heartbeater", e)
+ heartbeatFailures += 1
+ if (heartbeatFailures >= HEARTBEAT_MAX_FAILURES) {
+ logError(s"Exit as unable to send heartbeats to driver " +
+ s"more than $HEARTBEAT_MAX_FAILURES times")
+ System.exit(ExecutorExitCode.HEARTBEAT_FAILURE)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index ea36fb60bd54..99858f785600 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -39,6 +39,12 @@ object ExecutorExitCode {
/** ExternalBlockStore failed to create a local temporary directory after many attempts. */
val EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR = 55
+ /**
+   * The executor was unable to send heartbeats to the driver more than
+ * "spark.executor.heartbeat.maxFailures" times.
+ */
+ val HEARTBEAT_FAILURE = 56
+
def explainExitCode(exitCode: Int): String = {
exitCode match {
case UNCAUGHT_EXCEPTION => "Uncaught exception"
@@ -51,6 +57,8 @@ object ExecutorExitCode {
// TODO: replace external block store with concrete implementation name
case EXTERNAL_BLOCK_STORE_FAILED_TO_CREATE_DIR =>
"ExternalBlockStore failed to create a local temporary directory."
+ case HEARTBEAT_FAILURE =>
+ "Unable to send heartbeats to driver."
case _ =>
"Unknown executor exit code (" + exitCode + ")" + (
if (exitCode > 128) {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e0c9bf02a1a2..6103a10ccc50 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2195,6 +2195,16 @@ private[spark] object Utils extends Logging {
isInDirectory(parent, child.getParentFile)
}
+
+  /**
+   * @return whether the master URL in the given conf indicates local mode
+   */
+ def isLocalMaster(conf: SparkConf): Boolean = {
+ val master = conf.get("spark.master", "")
+ master == "local" || master.startsWith("local[")
+ }
+
/**
* Return whether dynamic allocation is enabled in the given conf
* Dynamic allocation and explicitly setting the number of executors are inherently
@@ -2202,8 +2212,13 @@ private[spark] object Utils extends Logging {
* the latter should override the former (SPARK-9092).
*/
def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
- conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
- conf.getInt("spark.executor.instances", 0) == 0
+ val numExecutor = conf.getInt("spark.executor.instances", 0)
+ val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false)
+ if (numExecutor != 0 && dynamicAllocationEnabled) {
+ logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.")
+ }
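+    // Enabled only when no explicit executor count is set and, outside of tests,
+    // the master is not local, since dynamic allocation requires a cluster manager.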
+ numExecutor == 0 && dynamicAllocationEnabled &&
+ (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false))
}
def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 7c6778b06546..412c0ac9d9be 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -722,6 +722,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
test("isDynamicAllocationEnabled") {
val conf = new SparkConf()
+ conf.set("spark.master", "yarn-client")
assert(Utils.isDynamicAllocationEnabled(conf) === false)
assert(Utils.isDynamicAllocationEnabled(
conf.set("spark.dynamicAllocation.enabled", "false")) === false)
@@ -731,6 +732,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
conf.set("spark.executor.instances", "1")) === false)
assert(Utils.isDynamicAllocationEnabled(
conf.set("spark.executor.instances", "0")) === true)
+ assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false)
+ assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true")))
}
test("encodeFileNameToURIRawPath") {
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 6febbf108900..b65d1a309cb4 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -488,7 +488,7 @@ def main():
if which("R"):
run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")])
else:
- print("Can't install SparkR as R is was not found in PATH")
+ print("Cannot install SparkR as R was not found in PATH")
if os.environ.get("AMPLAB_JENKINS"):
# if we're on the Amplab Jenkins build servers setup variables
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 4e04672ad39e..e4f2edaf9511 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -477,7 +477,7 @@ def __hash__(self):
],
sbt_test_goals=[
"yarn/test",
- "common/network-yarn/test",
+ "network-yarn/test",
],
test_tags=[
"org.apache.spark.tags.ExtendedYarnTest"
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
new file mode 100644
index 000000000000..e124c1cf1855
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.Arrays;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+// $example on$
+import org.apache.spark.ml.clustering.BisectingKMeans;
+import org.apache.spark.ml.clustering.BisectingKMeansModel;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off$
+
+
+/**
+ * An example demonstrating a bisecting k-means clustering.
+ */
+public class JavaBisectingKMeansExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // $example on$
+    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+ RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)),
+ RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)),
+ RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)),
+ RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)),
+ RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)),
+ RowFactory.create(Vectors.dense(18.9, 20.0, 19.7))
+ ));
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("features", new VectorUDT(), false, Metadata.empty()),
+ });
+
+ DataFrame dataset = jsql.createDataFrame(data, schema);
+
+ BisectingKMeans bkm = new BisectingKMeans().setK(2);
+ BisectingKMeansModel model = bkm.fit(dataset);
+
+ System.out.println("Compute Cost: " + model.computeCost(dataset));
+
+ Vector[] clusterCenters = model.clusterCenters();
+ for (int i = 0; i < clusterCenters.length; i++) {
+ Vector clusterCenter = clusterCenters[i];
+ System.out.println("Cluster Center " + i + ": " + clusterCenter);
+ }
+ // $example off$
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java
index 0001500f4fa5..c600094947d5 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java
@@ -33,7 +33,7 @@
// $example off$
/**
- * Java example for graph clustering using power iteration clustering (PIC).
+ * Java example for bisecting k-means clustering.
*/
public class JavaBisectingKMeansExample {
public static void main(String[] args) {
@@ -54,9 +54,7 @@ public static void main(String[] args) {
BisectingKMeansModel model = bkm.run(data);
System.out.println("Compute Cost: " + model.computeCost(data));
- for (Vector center: model.clusterCenters()) {
- System.out.println("");
- }
+
Vector[] clusterCenters = model.clusterCenters();
for (int i = 0; i < clusterCenters.length; i++) {
Vector clusterCenter = clusterCenters[i];
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala
index 0ec6214fdef1..6442b2a4e294 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala
@@ -62,7 +62,7 @@ object TwitterAlgebirdHLL {
var userSet: Set[Long] = Set()
val approxUsers = users.mapPartitions(ids => {
- ids.map(id => hll(id))
+ ids.map(id => hll.create(id))
}).reduce(_ + _)
val exactUsers = users.map(id => Set(id)).reduce(_ ++ _)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 61b364213181..55b751065664 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -156,6 +156,12 @@ private[ml] class WeightedLeastSquares(
private[ml] object WeightedLeastSquares {
+ /**
+ * In order to take the normal equation approach efficiently, [[WeightedLeastSquares]]
+   * only supports a number of features no greater than 4096.
+ */
+ val MAX_NUM_FEATURES: Int = 4096
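+  // Illustrative sizing behind this limit: the normal-equation aggregator stores a packed
+  // upper-triangular Gram matrix of k * (k + 1) / 2 doubles, which for k = 4096 is about
+  // 8.4 million entries, i.e. roughly 64 MB per aggregation buffer.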
+
/**
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
*/
@@ -174,8 +180,8 @@ private[ml] object WeightedLeastSquares {
private var aaSum: DenseVector = _
private def init(k: Int): Unit = {
- require(k <= 4096, "In order to take the normal equation approach efficiently, " +
- s"we set the max number of features to 4096 but got $k.")
+ require(k <= MAX_NUM_FEATURES, "In order to take the normal equation approach efficiently, " +
+ s"we set the max number of features to $MAX_NUM_FEATURES but got $k.")
this.k = k
triK = k * (k + 1) / 2
count = 0L
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
new file mode 100644
index 000000000000..a850dfee0a45
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -0,0 +1,577 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import breeze.stats.distributions.{Gaussian => GD}
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.optim._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+
+/**
+ * Params for Generalized Linear Regression.
+ */
+private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams
+ with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
+ with HasSolver with Logging {
+
+ /**
+   * Param for the name of the family, which describes the error distribution
+ * to be used in the model.
+ * Supported options: "gaussian", "binomial", "poisson" and "gamma".
+ * Default is "gaussian".
+ * @group param
+ */
+ @Since("2.0.0")
+ final val family: Param[String] = new Param(this, "family",
+ "The name of family which is a description of the error distribution to be used in the " +
+ "model. Supported options: gaussian(default), binomial, poisson and gamma.",
+ ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getFamily: String = $(family)
+
+ /**
+   * Param for the name of the link function, which provides the relationship
+ * between the linear predictor and the mean of the distribution function.
+ * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ * @group param
+ */
+ @Since("2.0.0")
+ final val link: Param[String] = new Param(this, "link", "The name of link function " +
+ "which provides the relationship between the linear predictor and the mean of the " +
+ "distribution function. Supported options: identity, log, inverse, logit, probit, " +
+ "cloglog and sqrt.",
+ ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getLink: String = $(link)
+
+ import GeneralizedLinearRegression._
+
+ @Since("2.0.0")
+ override def validateParams(): Unit = {
+ if ($(solver) == "irls") {
+ setDefault(maxIter -> 25)
+ }
+ if (isDefined(link)) {
+ require(supportedFamilyAndLinkPairs.contains(
+ Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
+ s"with ${$(family)} family does not support ${$(link)} link function.")
+ }
+ }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Fit a Generalized Linear Model ([[https://en.wikipedia.org/wiki/Generalized_linear_model]])
+ * specified by giving a symbolic description of the linear predictor (link function) and
+ * a description of the error distribution (family).
+ * It supports "gaussian", "binomial", "poisson" and "gamma" as family.
+ * Valid link functions for each family are listed below. The first link function of each family
+ * is the default one.
+ * - "gaussian" -> "identity", "log", "inverse"
+ * - "binomial" -> "logit", "probit", "cloglog"
+ * - "poisson" -> "log", "identity", "sqrt"
+ * - "gamma" -> "inverse", "identity", "log"
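+ *
+ * A minimal usage sketch (assumes a DataFrame `dataset` with "label" and "features" columns):
+ * {{{
+ *   val glr = new GeneralizedLinearRegression()
+ *     .setFamily("poisson")
+ *     .setLink("log")
+ *     .setMaxIter(10)
+ *     .setRegParam(0.3)
+ *   val model = glr.fit(dataset)
+ *   val predictions = model.transform(dataset)
+ * }}}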
+ */
+@Experimental
+@Since("2.0.0")
+class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
+ extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
+ with GeneralizedLinearRegressionBase with Logging {
+
+ import GeneralizedLinearRegression._
+
+ @Since("2.0.0")
+ def this() = this(Identifiable.randomUID("glm"))
+
+ /**
+ * Sets the value of param [[family]].
+ * Default is "gaussian".
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setFamily(value: String): this.type = set(family, value)
+ setDefault(family -> Gaussian.name)
+
+ /**
+ * Sets the value of param [[link]].
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setLink(value: String): this.type = set(link, value)
+
+ /**
+ * Sets if we should fit the intercept.
+ * Default is true.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+
+ /**
+ * Sets the maximum number of iterations.
+ * Default is 25 if the solver algorithm is "irls".
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Sets the convergence tolerance of iterations.
+   * A smaller value will lead to higher accuracy at the cost of more iterations.
+ * Default is 1E-6.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setTol(value: Double): this.type = set(tol, value)
+ setDefault(tol -> 1E-6)
+
+ /**
+ * Sets the regularization parameter.
+ * Default is 0.0.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ setDefault(regParam -> 0.0)
+
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is empty, so all instances have weight one.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+ setDefault(weightCol -> "")
+
+ /**
+ * Sets the solver algorithm used for optimization.
+   * Currently only "irls" is supported, which is also the default solver.
+ * @group setParam
+ */
+ @Since("2.0.0")
+ def setSolver(value: String): this.type = set(solver, value)
+ setDefault(solver -> "irls")
+
+ override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+ val familyObj = Family.fromName($(family))
+ val linkObj = if (isDefined(link)) {
+ Link.fromName($(link))
+ } else {
+ familyObj.defaultLink
+ }
+ val familyAndLink = new FamilyAndLink(familyObj, linkObj)
+
+ val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
+ .map { case Row(features: Vector) =>
+ features.size
+ }.first()
+ if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
+ val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
+ s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
+ throw new SparkException(msg)
+ }
+
+ val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
+ .map { case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
+
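+    // Gaussian family with the identity link is an ordinary (weighted) least squares problem,
+    // so it is solved directly with WeightedLeastSquares instead of going through IRLS.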
+ if (familyObj == Gaussian && linkObj == Identity) {
+ // TODO: Make standardizeFeatures and standardizeLabel configurable.
+ val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
+ standardizeFeatures = true, standardizeLabel = true)
+ val wlsModel = optimizer.fit(instances)
+ val model = copyValues(
+ new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
+ .setParent(this))
+ return model
+ }
+
+ // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
+ val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
+ val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc,
+ $(fitIntercept), $(regParam), $(maxIter), $(tol))
+ val irlsModel = optimizer.fit(instances)
+
+ val model = copyValues(
+ new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
+ .setParent(this))
+ model
+ }
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GeneralizedLinearRegression = defaultCopy(extra)
+}
+
+@Since("2.0.0")
+private[ml] object GeneralizedLinearRegression {
+
+ /** Set of family and link pairs that GeneralizedLinearRegression supports. */
+ lazy val supportedFamilyAndLinkPairs = Set(
+ Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
+ Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
+ Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
+ Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
+ )
+
+ /** Set of family names that GeneralizedLinearRegression supports. */
+ lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
+
+ /** Set of link names that GeneralizedLinearRegression supports. */
+ lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
+
+ val epsilon: Double = 1E-16
+
+ /**
+ * Wrapper of family and link combination used in the model.
+ */
+ private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
+
+ /** Linear predictor based on given mu. */
+ def predict(mu: Double): Double = link.link(family.project(mu))
+
+ /** Fitted value based on linear predictor eta. */
+ def fitted(eta: Double): Double = family.project(link.unlink(eta))
+
+ /**
+ * Get the initial guess model for [[IterativelyReweightedLeastSquares]].
+ */
+ def initialize(
+ instances: RDD[Instance],
+ fitIntercept: Boolean,
+ regParam: Double): WeightedLeastSquaresModel = {
+ val newInstances = instances.map { instance =>
+ val mu = family.initialize(instance.label, instance.weight)
+ val eta = predict(mu)
+ Instance(eta, instance.weight, instance.features)
+ }
+ // TODO: Make standardizeFeatures and standardizeLabel configurable.
+ val initialModel = new WeightedLeastSquares(fitIntercept, regParam,
+ standardizeFeatures = true, standardizeLabel = true)
+ .fit(newInstances)
+ initialModel
+ }
+
+ /**
+ * The reweight function used to update offsets and weights
+ * at each iteration of [[IterativelyReweightedLeastSquares]].
+ */
+ val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
+ (instance: Instance, model: WeightedLeastSquaresModel) => {
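+        // Standard IRLS update: the working response (offset) linearizes the link around mu,
+        //   offset = eta + (y - mu) * g'(mu)
+        // and the working weight is proportional to the inverse variance of that response,
+        //   weight = w / (g'(mu)^2 * Var(mu))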
+ val eta = model.predict(instance.features)
+ val mu = fitted(eta)
+ val offset = eta + (instance.label - mu) * link.deriv(mu)
+ val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
+ (offset, weight)
+ }
+ }
+ }
+
+ /**
+ * A description of the error distribution to be used in the model.
+ * @param name the name of the family.
+ */
+ private[ml] abstract class Family(val name: String) extends Serializable {
+
+ /** The default link instance of this family. */
+ val defaultLink: Link
+
+ /** Initialize the starting value for mu. */
+ def initialize(y: Double, weight: Double): Double
+
+    /** The variance of the endogenous variable as a function of its mean value mu. */
+ def variance(mu: Double): Double
+
+    /** Trim the fitted value so that it will be in the valid range. */
+ def project(mu: Double): Double = mu
+ }
+
+ private[ml] object Family {
+
+ /**
+ * Gets the [[Family]] object from its name.
+ * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
+ */
+ def fromName(name: String): Family = {
+ name match {
+ case Gaussian.name => Gaussian
+ case Binomial.name => Binomial
+ case Poisson.name => Poisson
+ case Gamma.name => Gamma
+ }
+ }
+ }
+
+ /**
+ * Gaussian exponential family distribution.
+ * The default link for the Gaussian family is the identity link.
+ */
+ private[ml] object Gaussian extends Family("gaussian") {
+
+ val defaultLink: Link = Identity
+
+ override def initialize(y: Double, weight: Double): Double = y
+
+ def variance(mu: Double): Double = 1.0
+
+ override def project(mu: Double): Double = {
+ if (mu.isNegInfinity) {
+ Double.MinValue
+ } else if (mu.isPosInfinity) {
+ Double.MaxValue
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * Binomial exponential family distribution.
+ * The default link for the Binomial family is the logit link.
+ */
+ private[ml] object Binomial extends Family("binomial") {
+
+ val defaultLink: Link = Logit
+
+ override def initialize(y: Double, weight: Double): Double = {
+ val mu = (weight * y + 0.5) / (weight + 1.0)
+      require(mu > 0.0 && mu < 1.0, "The response variable of Binomial family " +
+ s"should be in range (0, 1), but got $mu")
+ mu
+ }
+
+ override def variance(mu: Double): Double = mu * (1.0 - mu)
+
+ override def project(mu: Double): Double = {
+ if (mu < epsilon) {
+ epsilon
+ } else if (mu > 1.0 - epsilon) {
+ 1.0 - epsilon
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * Poisson exponential family distribution.
+ * The default link for the Poisson family is the log link.
+ */
+ private[ml] object Poisson extends Family("poisson") {
+
+ val defaultLink: Link = Log
+
+ override def initialize(y: Double, weight: Double): Double = {
+ require(y > 0.0, "The response variable of Poisson family " +
+ s"should be positive, but got $y")
+ y
+ }
+
+ override def variance(mu: Double): Double = mu
+
+ override def project(mu: Double): Double = {
+ if (mu < epsilon) {
+ epsilon
+ } else if (mu.isInfinity) {
+ Double.MaxValue
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * Gamma exponential family distribution.
+ * The default link for the Gamma family is the inverse link.
+ */
+ private[ml] object Gamma extends Family("gamma") {
+
+ val defaultLink: Link = Inverse
+
+ override def initialize(y: Double, weight: Double): Double = {
+ require(y > 0.0, "The response variable of Gamma family " +
+ s"should be positive, but got $y")
+ y
+ }
+
+ override def variance(mu: Double): Double = math.pow(mu, 2.0)
+
+ override def project(mu: Double): Double = {
+ if (mu < epsilon) {
+ epsilon
+ } else if (mu.isInfinity) {
+ Double.MaxValue
+ } else {
+ mu
+ }
+ }
+ }
+
+ /**
+ * A description of the link function to be used in the model.
+ * The link function provides the relationship between the linear predictor
+ * and the mean of the distribution function.
+ * @param name the name of link function.
+ */
+ private[ml] abstract class Link(val name: String) extends Serializable {
+
+ /** The link function. */
+ def link(mu: Double): Double
+
+ /** Derivative of the link function. */
+ def deriv(mu: Double): Double
+
+ /** The inverse link function. */
+ def unlink(eta: Double): Double
+ }
+
+ private[ml] object Link {
+
+ /**
+ * Gets the [[Link]] object from its name.
+ * @param name link name: "identity", "logit", "log",
+ * "inverse", "probit", "cloglog" or "sqrt".
+ */
+ def fromName(name: String): Link = {
+ name match {
+ case Identity.name => Identity
+ case Logit.name => Logit
+ case Log.name => Log
+ case Inverse.name => Inverse
+ case Probit.name => Probit
+ case CLogLog.name => CLogLog
+ case Sqrt.name => Sqrt
+ }
+ }
+ }
+
+ private[ml] object Identity extends Link("identity") {
+
+ override def link(mu: Double): Double = mu
+
+ override def deriv(mu: Double): Double = 1.0
+
+ override def unlink(eta: Double): Double = eta
+ }
+
+ private[ml] object Logit extends Link("logit") {
+
+ override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
+
+ override def deriv(mu: Double): Double = 1.0 / (mu * (1.0 - mu))
+
+ override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
+ }
+
+ private[ml] object Log extends Link("log") {
+
+ override def link(mu: Double): Double = math.log(mu)
+
+ override def deriv(mu: Double): Double = 1.0 / mu
+
+ override def unlink(eta: Double): Double = math.exp(eta)
+ }
+
+ private[ml] object Inverse extends Link("inverse") {
+
+ override def link(mu: Double): Double = 1.0 / mu
+
+ override def deriv(mu: Double): Double = -1.0 * math.pow(mu, -2.0)
+
+ override def unlink(eta: Double): Double = 1.0 / eta
+ }
+
+ private[ml] object Probit extends Link("probit") {
+
+ override def link(mu: Double): Double = GD(0.0, 1.0).icdf(mu)
+
+ override def deriv(mu: Double): Double = 1.0 / GD(0.0, 1.0).pdf(GD(0.0, 1.0).icdf(mu))
+
+ override def unlink(eta: Double): Double = GD(0.0, 1.0).cdf(eta)
+ }
+
+ private[ml] object CLogLog extends Link("cloglog") {
+
+ override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
+
+ override def deriv(mu: Double): Double = 1.0 / ((mu - 1.0) * math.log(1.0 - mu))
+
+ override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
+ }
+
+ private[ml] object Sqrt extends Link("sqrt") {
+
+ override def link(mu: Double): Double = math.sqrt(mu)
+
+ override def deriv(mu: Double): Double = 1.0 / (2.0 * math.sqrt(mu))
+
+ override def unlink(eta: Double): Double = math.pow(eta, 2.0)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model produced by [[GeneralizedLinearRegression]].
+ */
+@Experimental
+@Since("2.0.0")
+class GeneralizedLinearRegressionModel private[ml] (
+ @Since("2.0.0") override val uid: String,
+ @Since("2.0.0") val coefficients: Vector,
+ @Since("2.0.0") val intercept: Double)
+ extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
+ with GeneralizedLinearRegressionBase {
+
+ import GeneralizedLinearRegression._
+
+ lazy val familyObj = Family.fromName($(family))
+ lazy val linkObj = if (isDefined(link)) {
+ Link.fromName($(link))
+ } else {
+ familyObj.defaultLink
+ }
+ lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj)
+
+ override protected def predict(features: Vector): Double = {
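+    // Linear predictor eta = coefficients^T features + intercept, mapped back to the response
+    // scale through the inverse link function.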
+ val eta = BLAS.dot(features, coefficients) + intercept
+ familyAndLink.fitted(eta)
+ }
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
+ copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
+ .setParent(parent)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 8f78fd122f34..b4f17b8e2898 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -163,8 +163,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}.first()
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) ||
- $(solver) == "normal") {
+ if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
+ numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " +
"solver is used.'")
// For low dimensional data, WeightedLeastSquares is more efficiently since the
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index c3882606d7db..f807b5683c39 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -408,6 +408,10 @@ class LogisticRegressionWithLBFGS
* defaults to the mllib implementation. If more than two classes
* or feature scaling is disabled, always uses mllib implementation.
* Uses user provided weights.
+ *
+ * In the ml LogisticRegression implementation, the number of corrections
+   * used in the LBFGS update cannot be configured. So `optimizer.setNumCorrections()`
+   * will have no effect when that code path is taken.
*/
override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = {
run(input, initialWeights, userSuppliedWeights = true)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
new file mode 100644
index 000000000000..8bfa9855ce4e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -0,0 +1,507 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
+import org.apache.spark.mllib.random._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row}
+
+class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private val seed: Int = 42
+ @transient var datasetGaussianIdentity: DataFrame = _
+ @transient var datasetGaussianLog: DataFrame = _
+ @transient var datasetGaussianInverse: DataFrame = _
+ @transient var datasetBinomial: DataFrame = _
+ @transient var datasetPoissonLog: DataFrame = _
+ @transient var datasetPoissonIdentity: DataFrame = _
+ @transient var datasetPoissonSqrt: DataFrame = _
+ @transient var datasetGammaInverse: DataFrame = _
+ @transient var datasetGammaIdentity: DataFrame = _
+ @transient var datasetGammaLog: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ import GeneralizedLinearRegressionSuite._
+
+ datasetGaussianIdentity = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gaussian", link = "identity"), 2))
+
+ datasetGaussianLog = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gaussian", link = "log"), 2))
+
+ datasetGaussianInverse = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gaussian", link = "inverse"), 2))
+
+ datasetBinomial = {
+ val nPoints = 10000
+ val coefficients = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+
+ val testData =
+ generateMultinomialLogisticInput(coefficients, xMean, xVariance,
+ addIntercept = true, nPoints, seed)
+
+ sqlContext.createDataFrame(sc.parallelize(testData, 2))
+ }
+
+ datasetPoissonLog = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "poisson", link = "log"), 2))
+
+ datasetPoissonIdentity = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "poisson", link = "identity"), 2))
+
+ datasetPoissonSqrt = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "poisson", link = "sqrt"), 2))
+
+ datasetGammaInverse = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gamma", link = "inverse"), 2))
+
+ datasetGammaIdentity = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gamma", link = "identity"), 2))
+
+ datasetGammaLog = sqlContext.createDataFrame(
+ sc.parallelize(generateGeneralizedLinearRegressionInput(
+ intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+ xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
+ family = "gamma", link = "log"), 2))
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new GeneralizedLinearRegression)
+ val model = new GeneralizedLinearRegressionModel("genLinReg", Vectors.dense(0.0), 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("generalized linear regression: default params") {
+ val glr = new GeneralizedLinearRegression
+ assert(glr.getLabelCol === "label")
+ assert(glr.getFeaturesCol === "features")
+ assert(glr.getPredictionCol === "prediction")
+ assert(glr.getFitIntercept)
+ assert(glr.getTol === 1E-6)
+ assert(glr.getWeightCol === "")
+ assert(glr.getRegParam === 0.0)
+ assert(glr.getSolver == "irls")
+ // TODO: Construct model directly instead of via fitting.
+ val model = glr.setFamily("gaussian").setLink("identity")
+ .fit(datasetGaussianIdentity)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ assert(model.getFamily === "gaussian")
+ assert(model.getLink === "identity")
+ }
+
+ test("generalized linear regression: gaussian family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="gaussian", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2960999 0.8087933
+ [1] 2.5002642 2.2000403 0.5999485
+
+ data <- read.csv("path", header=FALSE)
+ model1 <- glm(f1, family=gaussian(link=log), data=data, start=c(0,0))
+ model2 <- glm(f2, family=gaussian(link=log), data=data, start=c(0,0,0))
+ print(as.vector(coef(model1)))
+ print(as.vector(coef(model2)))
+
+ [1] 0.23069326 0.07993778
+ [1] 0.25001858 0.22002452 0.05998789
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=gaussian(link=inverse), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.3010179 0.8198976
+ [1] 2.4108902 2.2130248 0.6086152
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, 2.2960999, 0.8087933),
+ Vectors.dense(2.5002642, 2.2000403, 0.5999485),
+ Vectors.dense(0.0, 0.23069326, 0.07993778),
+ Vectors.dense(0.25001858, 0.22002452, 0.05998789),
+ Vectors.dense(0.0, 2.3010179, 0.8198976),
+ Vectors.dense(2.4108902, 2.2130248, 0.6086152))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("identity", datasetGaussianIdentity), ("log", datasetGaussianLog),
+ ("inverse", datasetGaussianInverse))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("gaussian").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Gaussian, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"gaussian family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+
+ test("generalized linear regression: gaussian family against glmnet") {
+ /*
+ R code:
+ library(glmnet)
+ data <- read.csv("path", header=FALSE)
+ label = data$V1
+ features = as.matrix(data.frame(data$V2, data$V3))
+ for (intercept in c(FALSE, TRUE)) {
+ for (lambda in c(0.0, 0.1, 1.0)) {
+ model <- glmnet(features, label, family="gaussian", intercept=intercept,
+ lambda=lambda, alpha=0, thresh=1E-14)
+ print(as.vector(coef(model)))
+ }
+ }
+
+ [1] 0.0000000 2.2961005 0.8087932
+ [1] 0.0000000 2.2130368 0.8309556
+ [1] 0.0000000 1.7176137 0.9610657
+ [1] 2.5002642 2.2000403 0.5999485
+ [1] 3.1106389 2.0935142 0.5712711
+ [1] 6.7597127 1.4581054 0.3994266
+ */
+
+ val expected = Seq(
+ Vectors.dense(0.0, 2.2961005, 0.8087932),
+ Vectors.dense(0.0, 2.2130368, 0.8309556),
+ Vectors.dense(0.0, 1.7176137, 0.9610657),
+ Vectors.dense(2.5002642, 2.2000403, 0.5999485),
+ Vectors.dense(3.1106389, 2.0935142, 0.5712711),
+ Vectors.dense(6.7597127, 1.4581054, 0.3994266))
+
+ var idx = 0
+ for (fitIntercept <- Seq(false, true);
+ regParam <- Seq(0.0, 0.1, 1.0)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("gaussian")
+ .setFitIntercept(fitIntercept).setRegParam(regParam)
+ val model = trainer.fit(datasetGaussianIdentity)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gaussian family, " +
+ s"fitIntercept = $fitIntercept and regParam = $regParam.")
+
+ idx += 1
+ }
+ }
+
+ test("generalized linear regression: binomial family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3 + data$V4 + data$V5
+ data <- read.csv("path", header=FALSE)
+
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="binomial", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.3560284 1.3010002 -0.3570805 -0.7406762
+ [1] 2.8367406 -0.5896187 0.8931655 -0.3925169 -0.7996989
+
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=binomial(link=probit), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.2134390 0.7800646 -0.2144267 -0.4438358
+ [1] 1.6995366 -0.3524694 0.5332651 -0.2352985 -0.4780850
+
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=binomial(link=cloglog), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] -0.2832198 0.8434144 -0.2524727 -0.5293452
+ [1] 1.5063590 -0.4038015 0.6133664 -0.2687882 -0.5541758
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, -0.3560284, 1.3010002, -0.3570805, -0.7406762),
+ Vectors.dense(2.8367406, -0.5896187, 0.8931655, -0.3925169, -0.7996989),
+ Vectors.dense(0.0, -0.2134390, 0.7800646, -0.2144267, -0.4438358),
+ Vectors.dense(1.6995366, -0.3524694, 0.5332651, -0.2352985, -0.4780850),
+ Vectors.dense(0.0, -0.2832198, 0.8434144, -0.2524727, -0.5293452),
+ Vectors.dense(1.5063590, -0.4038015, 0.6133664, -0.2687882, -0.5541758))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("logit", datasetBinomial), ("probit", datasetBinomial),
+ ("cloglog", datasetBinomial))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("binomial").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1),
+ model.coefficients(2), model.coefficients(3))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with binomial family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Binomial, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"binomial family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+
+ test("generalized linear regression: poisson family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="poisson", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 0.22999393 0.08047088
+ [1] 0.25022353 0.21998599 0.05998621
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=poisson(link=identity), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2929501 0.8119415
+ [1] 2.5012730 2.1999407 0.5999107
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=poisson(link=sqrt), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2958947 0.8090515
+ [1] 2.5000480 2.1999972 0.5999968
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, 0.22999393, 0.08047088),
+ Vectors.dense(0.25022353, 0.21998599, 0.05998621),
+ Vectors.dense(0.0, 2.2929501, 0.8119415),
+ Vectors.dense(2.5012730, 2.1999407, 0.5999107),
+ Vectors.dense(0.0, 2.2958947, 0.8090515),
+ Vectors.dense(2.5000480, 2.1999972, 0.5999968))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("log", datasetPoissonLog), ("identity", datasetPoissonIdentity),
+ ("sqrt", datasetPoissonSqrt))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Poisson, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"poisson family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+
+ test("generalized linear regression: gamma family against glm") {
+ /*
+ R code:
+ f1 <- data$V1 ~ data$V2 + data$V3 - 1
+ f2 <- data$V1 ~ data$V2 + data$V3
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family="Gamma", data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.3392419 0.8058058
+ [1] 2.3507700 2.2533574 0.6042991
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=Gamma(link=identity), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 2.2908883 0.8147796
+ [1] 2.5002406 2.1998346 0.6000059
+
+ data <- read.csv("path", header=FALSE)
+ for (formula in c(f1, f2)) {
+ model <- glm(formula, family=Gamma(link=log), data=data)
+ print(as.vector(coef(model)))
+ }
+
+ [1] 0.22958970 0.08091066
+ [1] 0.25003210 0.21996957 0.06000215
+ */
+ val expected = Seq(
+ Vectors.dense(0.0, 2.3392419, 0.8058058),
+ Vectors.dense(2.3507700, 2.2533574, 0.6042991),
+ Vectors.dense(0.0, 2.2908883, 0.8147796),
+ Vectors.dense(2.5002406, 2.1998346, 0.6000059),
+ Vectors.dense(0.0, 0.22958970, 0.08091066),
+ Vectors.dense(0.25003210, 0.21996957, 0.06000215))
+
+ import GeneralizedLinearRegression._
+
+ var idx = 0
+ for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
+ ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
+ for (fitIntercept <- Seq(false, true)) {
+ val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
+ .setFitIntercept(fitIntercept)
+ val model = trainer.fit(dataset)
+ val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with gamma family, " +
+ s"$link link and fitIntercept = $fitIntercept.")
+
+ val familyLink = new FamilyAndLink(Gamma, Link.fromName(link))
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val eta = BLAS.dot(features, model.coefficients) + model.intercept
+ val prediction2 = familyLink.fitted(eta)
+ assert(prediction1 ~= prediction2 relTol 1E-5, "Prediction mismatch: GLM with " +
+ s"gamma family, $link link and fitIntercept = $fitIntercept.")
+ }
+
+ idx += 1
+ }
+ }
+ }
+}
+
+object GeneralizedLinearRegressionSuite {
+
+ def generateGeneralizedLinearRegressionInput(
+ intercept: Double,
+ coefficients: Array[Double],
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ nPoints: Int,
+ seed: Int,
+ noiseLevel: Double,
+ family: String,
+ link: String): Seq[LabeledPoint] = {
+
+ val rnd = new Random(seed)
+ def rndElement(i: Int) = {
+ (rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ }
+ val (generator, mean) = family match {
+ case "gaussian" => (new StandardNormalGenerator, 0.0)
+ case "poisson" => (new PoissonGenerator(1.0), 1.0)
+ case "gamma" => (new GammaGenerator(1.0, 1.0), 1.0)
+ }
+ generator.setSeed(seed)
+
+ (0 until nPoints).map { _ =>
+ val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray)
+ val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept
+ val mu = link match {
+ case "identity" => eta
+ case "log" => math.exp(eta)
+ case "sqrt" => math.pow(eta, 2.0)
+ case "inverse" => 1.0 / eta
+ }
+ val label = mu + noiseLevel * (generator.nextValue() - mean)
+ // Return LabeledPoints with DenseVector
+ LabeledPoint(label, features)
+ }
+ }
+}
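
For readers of the prediction assertions in the suite above: `familyLink.fitted(eta)` is simply the inverse link applied to the linear predictor, the same mapping the data generator uses for `mu`. A minimal Scala sketch for the log link, with made-up numbers (only `BLAS.dot` is replaced by a plain sum):

// Hedged illustration of the check prediction1 ~= familyLink.fitted(eta) for the log link.
val coefficients = Array(2.2, 0.6)          // hypothetical fitted coefficients
val intercept = 2.5                          // hypothetical fitted intercept
val features = Array(1.0, 2.0)               // one input row
val eta = coefficients.zip(features).map { case (c, x) => c * x }.sum + intercept
val mu = math.exp(eta)                       // inverse of the log link, cf. the generator's "log" case
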
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
index 77a2773c36f5..dcb1f398b04b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
@@ -42,6 +42,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
.collect()
/* Verify results using the `R` code:
+ library(arules)
transactions = as(sapply(
list("r z h k p",
"z y x w v u t s",
@@ -52,7 +53,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
FUN=function(x) strsplit(x," ",fixed=TRUE)),
"transactions")
ars = apriori(transactions,
- parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
+ parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2))
arsDF = as(ars, "data.frame")
arsDF$support = arsDF$support * length(transactions)
names(arsDF)[names(arsDF) == "support"] = "freq"
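
For context, the tightened support/confidence values in the R snippet above correspond to the MLlib association-rule run they verify. A hedged sketch of such a call, assuming `freqItemsets` (an RDD[FPGrowth.FreqItemset[String]]) already exists elsewhere in the suite:

import org.apache.spark.mllib.fpm.AssociationRules

// Hedged sketch: generate rules with the same confidence threshold the R code above now uses.
val rules = new AssociationRules()
  .setMinConfidence(0.9)
  .run(freqItemsets)               // freqItemsets is assumed to exist
rules.collect().foreach { r =>
  println(s"${r.antecedent.mkString(",")} => ${r.consequent.mkString(",")} : ${r.confidence}")
}
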
diff --git a/pom.xml b/pom.xml
index 2376e307ced1..2148379896d3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -89,7 +89,8 @@
common/sketch
common/network-common
common/network-shuffle
- tags
+ common/unsafe
+ common/tags
core
graphx
mllib
@@ -99,7 +100,6 @@
sql/core
sql/hive
docker-integration-tests
- unsafe
assembly
external/twitter
external/flume
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 3179fb30ab4d..253af15cb5cd 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -26,11 +26,12 @@
from pyspark.mllib.common import inherit_doc
-__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier',
- 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel',
- 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes',
- 'NaiveBayesModel', 'MultilayerPerceptronClassifier',
- 'MultilayerPerceptronClassificationModel']
+__all__ = ['LogisticRegression', 'LogisticRegressionModel',
+ 'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
+ 'GBTClassifier', 'GBTClassificationModel',
+ 'RandomForestClassifier', 'RandomForestClassificationModel',
+ 'NaiveBayes', 'NaiveBayesModel',
+ 'MultilayerPerceptronClassifier', 'MultilayerPerceptronClassificationModel']
@inherit_doc
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 611b9190491c..1cea477acb47 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -21,7 +21,8 @@
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
-__all__ = ['KMeans', 'KMeansModel', 'BisectingKMeans', 'BisectingKMeansModel']
+__all__ = ['BisectingKMeans', 'BisectingKMeansModel',
+ 'KMeans', 'KMeansModel']
class KMeansModel(JavaModel, MLWritable, MLReadable):
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 369f3508fda5..fb31c7310c0a 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -27,15 +27,34 @@
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
-__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT',
- 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString',
- 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinMaxScaler', 'MinMaxScalerModel',
- 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion',
- 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula', 'RFormulaModel',
- 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover',
- 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler',
- 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel', 'ChiSqSelector',
- 'ChiSqSelectorModel']
+__all__ = ['Binarizer',
+ 'Bucketizer',
+ 'ChiSqSelector', 'ChiSqSelectorModel',
+ 'CountVectorizer', 'CountVectorizerModel',
+ 'DCT',
+ 'ElementwiseProduct',
+ 'HashingTF',
+ 'IDF', 'IDFModel',
+ 'IndexToString',
+ 'MaxAbsScaler', 'MaxAbsScalerModel',
+ 'MinMaxScaler', 'MinMaxScalerModel',
+ 'NGram',
+ 'Normalizer',
+ 'OneHotEncoder',
+ 'PCA', 'PCAModel',
+ 'PolynomialExpansion',
+ 'QuantileDiscretizer',
+ 'RegexTokenizer',
+ 'RFormula', 'RFormulaModel',
+ 'SQLTransformer',
+ 'StandardScaler', 'StandardScalerModel',
+ 'StopWordsRemover',
+ 'StringIndexer', 'StringIndexerModel',
+ 'Tokenizer',
+ 'VectorAssembler',
+ 'VectorIndexer', 'VectorIndexerModel',
+ 'VectorSlicer',
+ 'Word2Vec', 'Word2VecModel']
@inherit_doc
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index b4d54ef61b0e..57106f8690a7 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -294,7 +294,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
(default: 0.01)
:param regType:
The type of regularizer used for training our model.
- Allowed values:
+ Supported values:
- "l1" for using L1 regularization
- "l2" for using L2 regularization (default)
@@ -326,7 +326,7 @@ class LogisticRegressionWithLBFGS(object):
"""
@classmethod
@since('1.2.0')
- def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
+ def train(cls, data, iterations=100, initialWeights=None, regParam=0.0, regType="l2",
intercept=False, corrections=10, tolerance=1e-6, validateData=True, numClasses=2):
"""
Train a logistic regression model on the given data.
@@ -341,10 +341,10 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
(default: None)
:param regParam:
The regularizer parameter.
- (default: 0.01)
+ (default: 0.0)
:param regType:
The type of regularizer used for training our model.
- Allowed values:
+ Supported values:
- "l1" for using L1 regularization
- "l2" for using L2 regularization (default)
@@ -356,7 +356,9 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
(default: False)
:param corrections:
The number of corrections used in the LBFGS update.
- (default: 10)
+ If a known updater is used for binary classification,
+ it calls the ml implementation and this parameter will
+ have no effect. (default: 10)
:param tolerance:
The convergence tolerance of iterations for L-BFGS.
(default: 1e-6)
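
The regParam and corrections notes above document the Python wrapper; a hedged Scala sketch of the MLlib class it wraps, assuming `training` is an existing RDD[LabeledPoint]:

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS

// Hedged sketch; `training` (RDD[LabeledPoint]) is assumed to exist.
val lr = new LogisticRegressionWithLBFGS().setNumClasses(2)
lr.optimizer
  .setRegParam(0.0)        // matches the new documented default above
  .setNumCorrections(10)   // per the note above, ignored when the binary case uses the ml implementation
  .setConvergenceTol(1e-6)
val model = lr.run(training)
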
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 4dd7083d79c8..3b77a6200054 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -37,10 +37,11 @@ class LabeledPoint(object):
"""
Class that represents the features and labels of a data point.
- :param label: Label for this data point.
- :param features: Vector of features for this point (NumPy array,
- list, pyspark.mllib.linalg.SparseVector, or scipy.sparse
- column matrix)
+ :param label:
+ Label for this data point.
+ :param features:
+ Vector of features for this point (NumPy array, list,
+ pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix).
Note: 'label' and 'features' are accessible as class attributes.
@@ -66,8 +67,10 @@ class LinearModel(object):
"""
A linear model that has a vector of coefficients and an intercept.
- :param weights: Weights computed for every feature.
- :param intercept: Intercept computed for this model.
+ :param weights:
+ Weights computed for every feature.
+ :param intercept:
+ Intercept computed for this model.
.. versionadded:: 0.9.0
"""
@@ -217,19 +220,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
class LinearRegressionWithSGD(object):
"""
- Train a linear regression model with no regularization using Stochastic Gradient Descent.
- This solves the least squares regression formulation
-
- f(weights) = 1/n ||A weights-y||^2
-
- which is the mean squared error.
- Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
- its corresponding right hand side label y.
- See also the documentation for the precise formulation.
-
.. versionadded:: 0.9.0
"""
-
@classmethod
@since("0.9.0")
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
@@ -237,47 +229,52 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
validateData=True, convergenceTol=0.001):
"""
Train a linear regression model using Stochastic Gradient
- Descent (SGD).
- This solves the least squares regression formulation
-
- f(weights) = 1/(2n) ||A weights - y||^2,
-
- which is the mean squared error.
- Here the data matrix has n rows, and the input RDD holds the
- set of rows of A, each with its corresponding right hand side
- label y. See also the documentation for the precise formulation.
-
- :param data: The training data, an RDD of
- LabeledPoint.
- :param iterations: The number of iterations
- (default: 100).
- :param step: The step parameter used in SGD
- (default: 1.0).
- :param miniBatchFraction: Fraction of data to be used for each
- SGD iteration (default: 1.0).
- :param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter
- (default: 0.0).
- :param regType: The type of regularizer used for
- training our model.
-
- :Allowed values:
- - "l1" for using L1 regularization (lasso),
- - "l2" for using L2 regularization (ridge),
- - None for no regularization
-
- (default: None)
-
- :param intercept: Boolean parameter which indicates the
- use or not of the augmented representation
- for training data (i.e. whether bias
- features are activated or not,
- default: False).
- :param validateData: Boolean parameter which indicates if
- the algorithm should validate data
- before training. (default: True)
- :param convergenceTol: A condition which decides iteration termination.
- (default: 0.001)
+ Descent (SGD). This solves the least squares regression
+ formulation
+
+ f(weights) = 1/(2n) ||A weights - y||^2
+
+ which is the mean squared error. Here the data matrix has n rows,
+ and the input RDD holds the set of rows of A, each with its
+ corresponding right hand side label y.
+ See also the documentation for the precise formulation.
+
+ :param data:
+ The training data, an RDD of LabeledPoint.
+ :param iterations:
+ The number of iterations.
+ (default: 100)
+ :param step:
+ The step parameter used in SGD.
+ (default: 1.0)
+ :param miniBatchFraction:
+ Fraction of data to be used for each SGD iteration.
+ (default: 1.0)
+ :param initialWeights:
+ The initial weights.
+ (default: None)
+ :param regParam:
+ The regularizer parameter.
+ (default: 0.0)
+ :param regType:
+ The type of regularizer used for training our model.
+ Supported values:
+
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization (default)
+ :param intercept:
+ Boolean parameter which indicates the use or not of the
+ augmented representation for training data (i.e., whether bias
+ features are activated or not).
+ (default: False)
+ :param validateData:
+ Boolean parameter which indicates if the algorithm should
+ validate data before training.
+ (default: True)
+ :param convergenceTol:
+ A condition which decides iteration termination.
+ (default: 0.001)
"""
def train(rdd, i):
return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
@@ -368,56 +365,53 @@ def load(cls, sc, path):
class LassoWithSGD(object):
"""
- Train a regression model with L1-regularization using Stochastic Gradient Descent.
- This solves the L1-regularized least squares regression formulation
-
- f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1
-
- Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
- its corresponding right hand side label y.
- See also the documentation for the precise formulation.
-
.. versionadded:: 0.9.0
"""
-
@classmethod
@since("0.9.0")
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None, intercept=False,
validateData=True, convergenceTol=0.001):
"""
- Train a regression model with L1-regularization using
- Stochastic Gradient Descent.
- This solves the l1-regularized least squares regression
- formulation
-
- f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1.
-
- Here the data matrix has n rows, and the input RDD holds the
- set of rows of A, each with its corresponding right hand side
- label y. See also the documentation for the precise formulation.
-
- :param data: The training data, an RDD of
- LabeledPoint.
- :param iterations: The number of iterations
- (default: 100).
- :param step: The step parameter used in SGD
- (default: 1.0).
- :param regParam: The regularizer parameter
- (default: 0.01).
- :param miniBatchFraction: Fraction of data to be used for each
- SGD iteration (default: 1.0).
- :param initialWeights: The initial weights (default: None).
- :param intercept: Boolean parameter which indicates the
- use or not of the augmented representation
- for training data (i.e. whether bias
- features are activated or not,
- default: False).
- :param validateData: Boolean parameter which indicates if
- the algorithm should validate data
- before training. (default: True)
- :param convergenceTol: A condition which decides iteration termination.
- (default: 0.001)
+ Train a regression model with L1-regularization using Stochastic
+ Gradient Descent. This solves the l1-regularized least squares
+ regression formulation
+
+ f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1
+
+ Here the data matrix has n rows, and the input RDD holds the set
+ of rows of A, each with its corresponding right hand side label y.
+ See also the documentation for the precise formulation.
+
+ :param data:
+ The training data, an RDD of LabeledPoint.
+ :param iterations:
+ The number of iterations.
+ (default: 100)
+ :param step:
+ The step parameter used in SGD.
+ (default: 1.0)
+ :param regParam:
+ The regularizer parameter.
+ (default: 0.01)
+ :param miniBatchFraction:
+ Fraction of data to be used for each SGD iteration.
+ (default: 1.0)
+ :param initialWeights:
+ The initial weights.
+ (default: None)
+ :param intercept:
+ Boolean parameter which indicates the use or not of the
+ augmented representation for training data (i.e. whether bias
+ features are activated or not).
+ (default: False)
+ :param validateData:
+ Boolean parameter which indicates if the algorithm should
+ validate data before training.
+ (default: True)
+ :param convergenceTol:
+ A condition which decides iteration termination.
+ (default: 0.001)
"""
def train(rdd, i):
return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
@@ -508,56 +502,53 @@ def load(cls, sc, path):
class RidgeRegressionWithSGD(object):
"""
- Train a regression model with L2-regularization using Stochastic Gradient Descent.
- This solves the L2-regularized least squares regression formulation
-
- f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2
-
- Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with
- its corresponding right hand side label y.
- See also the documentation for the precise formulation.
-
.. versionadded:: 0.9.0
"""
-
@classmethod
@since("0.9.0")
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None, intercept=False,
validateData=True, convergenceTol=0.001):
"""
- Train a regression model with L2-regularization using
- Stochastic Gradient Descent.
- This solves the l2-regularized least squares regression
- formulation
-
- f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2.
-
- Here the data matrix has n rows, and the input RDD holds the
- set of rows of A, each with its corresponding right hand side
- label y. See also the documentation for the precise formulation.
-
- :param data: The training data, an RDD of
- LabeledPoint.
- :param iterations: The number of iterations
- (default: 100).
- :param step: The step parameter used in SGD
- (default: 1.0).
- :param regParam: The regularizer parameter
- (default: 0.01).
- :param miniBatchFraction: Fraction of data to be used for each
- SGD iteration (default: 1.0).
- :param initialWeights: The initial weights (default: None).
- :param intercept: Boolean parameter which indicates the
- use or not of the augmented representation
- for training data (i.e. whether bias
- features are activated or not,
- default: False).
- :param validateData: Boolean parameter which indicates if
- the algorithm should validate data
- before training. (default: True)
- :param convergenceTol: A condition which decides iteration termination.
- (default: 0.001)
+ Train a regression model with L2-regularization using Stochastic
+ Gradient Descent. This solves the l2-regularized least squares
+ regression formulation
+
+ f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2
+
+ Here the data matrix has n rows, and the input RDD holds the set
+ of rows of A, each with its corresponding right hand side label y.
+ See also the documentation for the precise formulation.
+
+ :param data:
+ The training data, an RDD of LabeledPoint.
+ :param iterations:
+ The number of iterations.
+ (default: 100)
+ :param step:
+ The step parameter used in SGD.
+ (default: 1.0)
+ :param regParam:
+ The regularizer parameter.
+ (default: 0.01)
+ :param miniBatchFraction:
+ Fraction of data to be used for each SGD iteration.
+ (default: 1.0)
+ :param initialWeights:
+ The initial weights.
+ (default: None)
+ :param intercept:
+ Boolean parameter which indicates the use or not of the
+ augmented representation for training data (i.e. whether bias
+ features are activated or not).
+ (default: False)
+ :param validateData:
+ Boolean parameter which indicates if the algorithm should
+ validate data before training.
+ (default: True)
+ :param convergenceTol:
+ A condition which decides iteration termination.
+ (default: 0.001)
"""
def train(rdd, i):
return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
@@ -572,12 +563,14 @@ class IsotonicRegressionModel(Saveable, Loader):
"""
Regression model for isotonic regression.
- :param boundaries: Array of boundaries for which predictions are
- known. Boundaries must be sorted in increasing order.
- :param predictions: Array of predictions associated to the
- boundaries at the same index. Results of isotonic
- regression and therefore monotone.
- :param isotonic: indicates whether this is isotonic or antitonic.
+ :param boundaries:
+ Array of boundaries for which predictions are known. Boundaries
+ must be sorted in increasing order.
+ :param predictions:
+ Array of predictions associated to the boundaries at the same
+ index. Results of isotonic regression and therefore monotone.
+ :param isotonic:
+ Indicates whether this is isotonic or antitonic.
>>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
>>> irm = IsotonicRegression.train(sc.parallelize(data))
@@ -628,7 +621,8 @@ def predict(self, x):
values with the same boundary then the same rules as in 2)
are used.
- :param x: Feature or RDD of Features to be labeled.
+ :param x:
+ Feature or RDD of Features to be labeled.
"""
if isinstance(x, RDD):
return x.map(lambda v: self.predict(v))
@@ -657,8 +651,8 @@ def load(cls, sc, path):
class IsotonicRegression(object):
"""
Isotonic regression.
- Currently implemented using parallelized pool adjacent violators algorithm.
- Only univariate (single feature) algorithm supported.
+ Currently implemented using the parallelized pool adjacent violators
+ algorithm. Only the univariate (single feature) algorithm is supported.
Sequential PAV implementation based on:
@@ -684,8 +678,11 @@ def train(cls, data, isotonic=True):
"""
Train a isotonic regression model on the given data.
- :param data: RDD of (label, feature, weight) tuples.
- :param isotonic: Whether this is isotonic or antitonic.
+ :param data:
+ RDD of (label, feature, weight) tuples.
+ :param isotonic:
+ Whether this is isotonic (the default) or antitonic.
+ (default: True)
"""
boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel",
data.map(_convert_to_vector), bool(isotonic))
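
The `train` wrapper above hands the data to `trainIsotonicRegressionModel` on the JVM; a hedged sketch of the Scala MLlib call it corresponds to, assuming `sc` is an existing SparkContext:

import org.apache.spark.mllib.regression.IsotonicRegression

// Hedged sketch: (label, feature, weight) tuples, matching the Python docstring above.
val data = sc.parallelize(Seq((1.0, 0.0, 1.0), (2.0, 1.0, 1.0), (3.0, 2.0, 1.0), (1.0, 3.0, 1.0)))
val model = new IsotonicRegression().setIsotonic(true).run(data)
println(model.predict(1.5))        // interpolates between the fitted boundary predictions
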
@@ -721,9 +718,11 @@ def _validate(self, dstream):
@since("1.5.0")
def predictOn(self, dstream):
"""
- Make predictions on a dstream.
+ Use the model to make predictions on batches of data from a
+ DStream.
- :return: Transformed dstream object.
+ :return:
+ DStream containing predictions.
"""
self._validate(dstream)
return dstream.map(lambda x: self._model.predict(x))
@@ -731,9 +730,11 @@ def predictOn(self, dstream):
@since("1.5.0")
def predictOnValues(self, dstream):
"""
- Make predictions on a keyed dstream.
+ Use the model to make predictions on the values of a DStream and
+ carry over its keys.
- :return: Transformed dstream object.
+ :return:
+ DStream containing the input keys and the predictions as values.
"""
self._validate(dstream)
return dstream.mapValues(lambda x: self._model.predict(x))
@@ -742,14 +743,15 @@ def predictOnValues(self, dstream):
@inherit_doc
class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
"""
- Train or predict a linear regression model on streaming data. Training uses
- Stochastic Gradient Descent to update the model based on each new batch of
- incoming data from a DStream (see `LinearRegressionWithSGD` for model equation).
+ Train or predict a linear regression model on streaming data.
+ Training uses Stochastic Gradient Descent to update the model
+ based on each new batch of incoming data from a DStream
+ (see `LinearRegressionWithSGD` for model equation).
Each batch of data is assumed to be an RDD of LabeledPoints.
The number of data points per batch can vary, but the number
- of features must be constant. An initial weight
- vector must be provided.
+ of features must be constant. An initial weight vector must
+ be provided.
:param stepSize:
Step size for each iteration of gradient descent.
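
Putting the streaming docstrings above together, a hedged Scala sketch of the train/predict cycle, assuming `trainingStream` (DStream[LabeledPoint]) and `featureStream` (DStream[Vector]) already exist and carry a constant number of features:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

// Hedged sketch of the streaming workflow described above.
val model = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(2))   // an initial weight vector must be provided
  .setStepSize(0.1)
model.trainOn(trainingStream)            // update the model on each new batch
model.predictOn(featureStream).print()   // DStream of predictions
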
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b1453c637f79..7f5368d8bdbb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -233,6 +233,23 @@ def text(self, paths):
paths = [paths]
return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths)))
+ @since(2.0)
+ def csv(self, paths):
+ """Loads a CSV file and returns the result as a [[DataFrame]].
+
+ This function goes through the input once to determine the input schema. To avoid going
+ through the entire data once, specify the schema explicitly using ``schema``.
+
+ :param paths: string, or list of strings, for input path(s).
+
+ >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv')
+ >>> df.dtypes
+ [('C0', 'string'), ('C1', 'string')]
+ """
+ if isinstance(paths, basestring):
+ paths = [paths]
+ return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths)))
+
@since(1.5)
def orc(self, path):
"""Loads an ORC file, returning the result as a :class:`DataFrame`.
@@ -448,6 +465,11 @@ def json(self, path, mode=None):
* ``ignore``: Silently ignore this operation if data already exists.
* ``error`` (default case): Throw an exception if data already exists.
+ You can set the following JSON-specific option(s) for writing JSON files:
+ * ``compression`` (default ``None``): compression codec to use when saving to file.
+ This can be one of the known case-insensitive shortened names
+ (``bzip2``, ``gzip``, ``lz4``, and ``snappy``).
+
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)._jwrite.json(path)
@@ -476,11 +498,39 @@ def parquet(self, path, mode=None, partitionBy=None):
def text(self, path):
"""Saves the content of the DataFrame in a text file at the specified path.
+ :param path: the path in any Hadoop supported file system
+
The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
+
+ You can set the following option(s) for writing text files:
+ * ``compression`` (default ``None``): compression codec to use when saving to file.
+ This can be one of the known case-insensitive shortened names
+ (``bzip2``, ``gzip``, ``lz4``, and ``snappy``).
"""
self._jwrite.text(path)
+ @since(2.0)
+ def csv(self, path, mode=None):
+ """Saves the content of the [[DataFrame]] in CSV format at the specified path.
+
+ :param path: the path in any Hadoop supported file system
+ :param mode: specifies the behavior of the save operation when data already exists.
+
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
+
+ You can set the following CSV-specific option(s) for writing CSV files:
+ * ``compression`` (default ``None``): compression codec to use when saving to file.
+ This can be one of the known case-insensitive shortened names
+ (``bzip2``, ``gzip``, ``lz4``, and ``snappy``).
+
+ >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ self.mode(mode)._jwrite.csv(path)
+
@since(1.5)
def orc(self, path, mode=None, partitionBy=None):
"""Saves the content of the :class:`DataFrame` in ORC format at the specified path.
diff --git a/python/test_support/sql/ages.csv b/python/test_support/sql/ages.csv
new file mode 100644
index 000000000000..18991feda788
--- /dev/null
+++ b/python/test_support/sql/ages.csv
@@ -0,0 +1,4 @@
+Joe,20
+Tom,30
+Hyukjin,25
+
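
The new Python `csv` reader and writer above are thin wrappers over the JVM DataFrameReader/DataFrameWriter; a hedged Scala sketch of the equivalent calls, assuming `sqlContext` exists and the Scala-side `csv` methods are available:

// Hedged sketch of the calls the Python wrappers delegate to.
val df = sqlContext.read.csv("python/test_support/sql/ages.csv")
df.dtypes.foreach(println)                          // two string columns, as in the doctest above
df.write.mode("overwrite").csv("/tmp/ages-csv-out") // hypothetical output path
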
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 27ae62f1212f..0ad0f4976c77 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -36,7 +36,7 @@
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
-final class UnsafeExternalRowSorter {
+public final class UnsafeExternalRowSorter {
/**
* If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) {
testSpillFrequency = frequency;
}
- @VisibleForTesting
- void insertRow(UnsafeRow row) throws IOException {
+ public void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
@@ -110,8 +109,7 @@ private void cleanupResources() {
sorter.cleanupResources();
}
- @VisibleForTesting
- Iterator<UnsafeRow> sort() throws IOException {
+ public Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ public UnsafeRow next() {
}
}
-
public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
while (inputIterator.hasNext()) {
insertRow(inputIterator.next());
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 23e4709bbd88..876aa0eae0e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import java.lang.reflect.Modifier
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
@@ -559,7 +561,13 @@ class Analyzer(
}
resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
- case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+ case n: NewInstance
+ // If this is an inner class of another class, register the outer object in `OuterScopes`.
+ // Note that static inner classes (e.g., inner classes within Scala objects) don't need
+ // outer pointer registration.
+ if n.outerPointer.isEmpty &&
+ n.cls.isMemberClass &&
+ !Modifier.isStatic(n.cls.getModifiers) =>
val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
if (outer == null) {
throw new AnalysisException(
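
A hedged illustration of the new `Modifier.isStatic` guard above: per the comment in the patch, classes nested in Scala objects compile to static inner classes and need no outer pointer, while members of a class instance do. The names below are invented for the example:

import java.lang.reflect.Modifier

// Hedged illustration of the check: only non-static member classes need an outer pointer.
object Outer { case class Inner(i: Int) }      // nested in an object: compiled as a static inner class
class Wrapper { case class Member(i: Int) }    // member of a class instance: non-static

println(Modifier.isStatic(classOf[Outer.Inner].getModifiers))    // expected true, per the comment above -> skipped by the guard
println(Modifier.isStatic(classOf[Wrapper#Member].getModifiers)) // expected false -> OuterScopes lookup still runs
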
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 4be065b30a21..3ee19cc4ad71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import java.text.DecimalFormat
+import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
import org.apache.spark.sql.catalyst.InternalRow
@@ -938,8 +938,10 @@ case class FormatNumber(x: Expression, d: Expression)
@transient
private val pattern: StringBuffer = new StringBuffer()
+ // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
+ // as a decimal separator.
@transient
- private val numberFormat: DecimalFormat = new DecimalFormat("")
+ private val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
val dValue = dObject.asInstanceOf[Int]
@@ -962,10 +964,9 @@ case class FormatNumber(x: Expression, d: Expression)
pattern.append("0")
}
}
- val dFormat = new DecimalFormat(pattern.toString)
lastDValue = dValue
- numberFormat.applyPattern(dFormat.toPattern)
+ numberFormat.applyLocalizedPattern(pattern.toString)
}
x.dataType match {
@@ -992,6 +993,11 @@ case class FormatNumber(x: Expression, d: Expression)
val sb = classOf[StringBuffer].getName
val df = classOf[DecimalFormat].getName
+ val dfs = classOf[DecimalFormatSymbols].getName
+ val l = classOf[Locale].getName
+ // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
+ // as a decimal separator.
+ val usLocale = "US"
val lastDValue = ctx.freshName("lastDValue")
val pattern = ctx.freshName("pattern")
val numberFormat = ctx.freshName("numberFormat")
@@ -999,7 +1005,8 @@ case class FormatNumber(x: Expression, d: Expression)
val dFormat = ctx.freshName("dFormat")
ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;")
ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
- ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""")
+ ctx.addMutableState(df, numberFormat,
+ s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""")
s"""
if ($d >= 0) {
@@ -1013,9 +1020,8 @@ case class FormatNumber(x: Expression, d: Expression)
$pattern.append("0");
}
}
- $df $dFormat = new $df($pattern.toString());
$lastDValue = $d;
- $numberFormat.applyPattern($dFormat.toPattern());
+ $numberFormat.applyLocalizedPattern($pattern.toString());
}
${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
} else {
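
A hedged, stand-alone illustration of the SPARK-13515 fix above: constructing the formatter with explicit US symbols pins the decimal separator to '.' regardless of the JVM default locale:

import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.Locale

// With US symbols the separator is always '.', even on locales that default to ','.
val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
numberFormat.applyLocalizedPattern("#,###,###,###,###,###,##0.00")
println(numberFormat.format(12345.678))   // "12,345.68" on any default locale
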
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8095083f336e..31e775d60f95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan {
override def children: Seq[LogicalPlan] = child :: Nil
+ /**
+ * Generates an additional set of aliased constraints by replacing the original constraint
+ * expressions with the corresponding alias.
+ */
+ protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
+ projectList.flatMap {
+ case a @ Alias(e, _) =>
+ child.constraints.map(_ transform {
+ case expr: Expression if expr.semanticEquals(e) =>
+ a.toAttribute
+ }).union(Set(EqualNullSafe(e, a.toAttribute)))
+ case _ =>
+ Set.empty[Expression]
+ }.toSet
+ }
+
override protected def validConstraints: Set[Expression] = child.constraints
override def statistics: Statistics = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5d2a65b716b2..e81a0f948746 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}
- /**
- * Generates an additional set of aliased constraints by replacing the original constraint
- * expressions with the corresponding alias
- */
- private def getAliasedConstraints: Set[Expression] = {
- projectList.flatMap {
- case a @ Alias(e, _) =>
- child.constraints.map(_ transform {
- case expr: Expression if expr.semanticEquals(e) =>
- a.toAttribute
- }).union(Set(EqualNullSafe(e, a.toAttribute)))
- case _ =>
- Set.empty[Expression]
- }.toSet
- }
-
- override def validConstraints: Set[Expression] = {
- child.constraints.union(getAliasedConstraints)
- }
+ override def validConstraints: Set[Expression] =
+ child.constraints.union(getAliasedConstraints(projectList))
}
/**
@@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
override def maxRows: Option[Long] = child.maxRows
- override protected def validConstraints: Set[Expression] = {
+ override protected def validConstraints: Set[Expression] =
child.constraints.union(splitConjunctivePredicates(condition).toSet)
- }
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
- override protected def validConstraints: Set[Expression] = {
+ override protected def validConstraints: Set[Expression] =
leftConstraints.union(rightConstraints)
- }
// Intersect are only resolved if they don't introduce ambiguous expression ids,
// since the Optimizer will convert Intersect to Join.
@@ -442,6 +423,9 @@ case class Aggregate(
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows
+ override def validConstraints: Set[Expression] =
+ child.constraints.union(getAliasedConstraints(aggregateExpressions))
+
override def statistics: Statistics = {
if (groupingExpressions.isEmpty) {
Statistics(sizeInBytes = 1)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 38ce1604b1ed..6a59e9728a9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -340,6 +340,9 @@ object Decimal {
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
+ /** Maximum number of decimal digits an Int can represent */
+ val MAX_INT_DIGITS = 9
+
/** Maximum number of decimal digits a Long can represent */
val MAX_LONG_DIGITS = 18
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 2e03ddae760b..9c1319c1c5e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType {
}
}
+ /**
+ * Returns if dt is a DecimalType that fits inside an int
+ */
+ def is32BitDecimalType(dt: DataType): Boolean = {
+ dt match {
+ case t: DecimalType =>
+ t.precision <= Decimal.MAX_INT_DIGITS
+ case _ => false
+ }
+ }
+
/**
* Returns if dt is a DecimalType that fits inside a long
*/
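
The 9- and 18-digit cut-offs used above follow directly from the Int and Long ranges; a hedged sanity check:

// MAX_INT_DIGITS = 9 and MAX_LONG_DIGITS = 18:
// any 9-digit unscaled value fits in an Int, any 18-digit one fits in a Long.
println(Int.MaxValue)    // 2147483647          -> 10 digits, so only 9 full decimal digits are safe
println(Long.MaxValue)   // 9223372036854775807 -> 19 digits, so only 18 full decimal digits are safe
println(999999999L <= Int.MaxValue)               // true
println(999999999999999999L <= Long.MaxValue)     // true
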
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 373b1ffa83d2..b68432b1a128 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "c"))))
}
+ test("propagating constraints in aggregate") {
+ val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+ .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze
+
+ verifyConstraints(aliasedRelation.analyze.constraints,
+ Set(resolveColumn(aliasedRelation.analyze, "c1") > 10,
+ IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")),
+ resolveColumn(aliasedRelation.analyze, "a") < 5,
+ IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))
+ }
+
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index e7f0ec2e7789..57dbd7c2ff56 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -257,8 +257,7 @@ private void initializeInternal() throws IOException {
throw new IOException("Unsupported type: " + t);
}
if (originalTypes[i] == OriginalType.DECIMAL &&
- primitiveType.getDecimalMetadata().getPrecision() >
- CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) {
+ primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
throw new IOException("Decimal with high precision is not supported.");
}
if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
@@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept
PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
int precision = type.getDecimalMetadata().getPrecision();
int scale = type.getDecimalMetadata().getScale();
- Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(),
+ Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(),
"Unsupported precision.");
for (int n = 0; n < num; ++n) {
@@ -480,11 +479,6 @@ private final class ColumnReader {
*/
private boolean useDictionary;
- /**
- * If useDictionary is true, the staging vector used to decode the ids.
- */
- private ColumnVector dictionaryIds;
-
/**
* Maximum definition level for this column.
*/
@@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException {
}
int num = Math.min(total, leftInPage);
if (useDictionary) {
- // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
- if (dictionaryIds == null) {
- dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
- } else {
- dictionaryIds.reset();
- dictionaryIds.reserve(total);
- }
// Read and decode dictionary ids.
+ ColumnVector dictionaryIds = column.reserveDictionaryIds(total);
defColumn.readIntegers(
num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
- decodeDictionaryIds(rowId, num, column);
+ decodeDictionaryIds(rowId, num, column, dictionaryIds);
} else {
+ column.setDictionary(null);
switch (descriptor.getType()) {
case BOOLEAN:
readBooleanBatch(rowId, num, column);
@@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException {
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
- private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+ private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
+ ColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
- if (column.dataType() == DataTypes.IntegerType) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else if (column.dataType() == DataTypes.ByteType) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else if (column.dataType() == DataTypes.ShortType) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else if (DecimalType.is64BitDecimalType(column.dataType())) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
- }
- } else {
- throw new NotImplementedException("Unimplemented type: " + column.dataType());
- }
- break;
-
case INT64:
- if (column.dataType() == DataTypes.LongType ||
- DecimalType.is64BitDecimalType(column.dataType())) {
- for (int i = rowId; i < rowId + num; ++i) {
- column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
- }
- } else {
- throw new NotImplementedException("Unimplemented type: " + column.dataType());
- }
- break;
-
case FLOAT:
- for (int i = rowId; i < rowId + num; ++i) {
- column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
- }
- break;
-
case DOUBLE:
- for (int i = rowId; i < rowId + num; ++i) {
- column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
- }
+ case BINARY:
+ column.setDictionary(dictionary);
break;
case FIXED_LEN_BYTE_ARRAY:
- if (DecimalType.is64BitDecimalType(column.dataType())) {
+ // DecimalType written in the legacy mode
+ if (DecimalType.is32BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v));
+ }
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
@@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
}
break;
- case BINARY:
- // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
- // need to do this better. We should probably add the dictionary data to the ColumnVector
- // and reuse it across batches. This should mean adding a ByteArray would just update
- // the length and offset.
- for (int i = rowId; i < rowId + num; ++i) {
- Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
- column.putByteArray(i, v.getBytes());
- }
- break;
-
default:
throw new NotImplementedException("Unsupported type: " + descriptor.getType());
}
@@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IO
private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) {
+ if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
+ DecimalType.is32BitDecimalType(column.dataType())) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ByteType) {
defColumn.readBytes(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
- } else if (DecimalType.is64BitDecimalType(column.dataType())) {
- defColumn.readIntsAsLongs(
- num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ShortType) {
defColumn.readShorts(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num,
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
- if (DecimalType.is64BitDecimalType(column.dataType())) {
+ if (DecimalType.is32BitDecimalType(column.dataType())) {
+ for (int i = 0; i < num; i++) {
+ if (defColumn.readInteger() == maxDefLevel) {
+ column.putInt(rowId + i,
+ (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+ } else {
+ column.putNull(rowId + i);
+ }
+ }
+ } else if (DecimalType.is64BitDecimalType(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putLong(rowId + i,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 8613fcae0b80..62157389013b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -25,7 +25,6 @@
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.io.api.Binary;
-import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
/**
@@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c,
}
}
- public void readIntsAsLongs(int total, ColumnVector c,
- int rowId, int level, VectorizedValuesReader data) {
- int left = total;
- while (left > 0) {
- if (this.currentCount == 0) this.readNextGroup();
- int n = Math.min(left, this.currentCount);
- switch (mode) {
- case RLE:
- if (currentValue == level) {
- for (int i = 0; i < n; i++) {
- c.putLong(rowId + i, data.readInteger());
- }
- } else {
- c.putNulls(rowId, n);
- }
- break;
- case PACKED:
- for (int i = 0; i < n; ++i) {
- if (currentBuffer[currentBufferIdx++] == level) {
- c.putLong(rowId + i, data.readInteger());
- } else {
- c.putNull(rowId + i);
- }
- }
- break;
- }
- rowId += n;
- left -= n;
- currentCount -= n;
- }
- }
-
public void readBytes(int total, ColumnVector c,
int rowId, int level, VectorizedValuesReader data) {
int left = total;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 0514252a8e53..bb0247c2fbed 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -19,6 +19,10 @@
import java.math.BigDecimal;
import java.math.BigInteger;
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.parquet.column.Dictionary;
+import org.apache.parquet.io.api.Binary;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -27,8 +31,6 @@
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import org.apache.commons.lang.NotImplementedException;
-
/**
* This class represents a column of values and provides the main APIs to access the data
* values. It supports all the types and contains get/put APIs as well as their batched versions.
@@ -157,7 +159,7 @@ public Object[] array() {
} else if (dt instanceof StringType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
- list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
+ list[i] = getUTF8String(i).toString();
}
}
} else if (dt instanceof CalendarIntervalType) {
@@ -204,28 +206,17 @@ public float getFloat(int ordinal) {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- if (precision <= Decimal.MAX_LONG_DIGITS()) {
- return Decimal.apply(getLong(ordinal), precision, scale);
- } else {
- byte[] bytes = getBinary(ordinal);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(javaDecimal, precision, scale);
- }
+ return data.getDecimal(offset + ordinal, precision, scale);
}
@Override
public UTF8String getUTF8String(int ordinal) {
- Array child = data.getByteArray(offset + ordinal);
- return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length);
+ return data.getUTF8String(offset + ordinal);
}
@Override
public byte[] getBinary(int ordinal) {
- ColumnVector.Array array = data.getByteArray(offset + ordinal);
- byte[] bytes = new byte[array.length];
- System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
- return bytes;
+ return data.getBinary(offset + ordinal);
}
@Override
@@ -534,12 +525,57 @@ public final int putByteArray(int rowId, byte[] value) {
/**
* Returns the value for rowId.
*/
- public final Array getByteArray(int rowId) {
+ private Array getByteArray(int rowId) {
Array array = getArray(rowId);
array.data.loadBytes(array);
return array;
}
+ /**
+ * Returns the decimal for rowId.
+ */
+ public final Decimal getDecimal(int rowId, int precision, int scale) {
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ return Decimal.apply(getInt(rowId), precision, scale);
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.apply(getLong(rowId), precision, scale);
+ } else {
+ // TODO: best perf?
+ byte[] bytes = getBinary(rowId);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
+ }
+
+ /**
+ * Returns the UTF8String for rowId.
+ */
+ public final UTF8String getUTF8String(int rowId) {
+ if (dictionary == null) {
+ ColumnVector.Array a = getByteArray(rowId);
+ return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
+ } else {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId));
+ return UTF8String.fromBytes(v.getBytes());
+ }
+ }
+
+ /**
+ * Returns the byte array for rowId.
+ */
+ public final byte[] getBinary(int rowId) {
+ if (dictionary == null) {
+ ColumnVector.Array array = getByteArray(rowId);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
+ } else {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId));
+ return v.getBytes();
+ }
+ }
+
/**
* Append APIs. These APIs all behave similarly and will append data to the current vector. It
* is not valid to mix the put and append APIs. The append APIs are slower and should only be
@@ -816,6 +852,39 @@ public final int appendStruct(boolean isNull) {
*/
protected final ColumnarBatch.Row resultStruct;
+ /**
+ * The Dictionary for this column.
+ *
+ * If it is not null, it will be used to decode the values in the getXXX() methods.
+ */
+ protected Dictionary dictionary;
+
+ /**
+ * Reusable column for the dictionary ids.
+ */
+ protected ColumnVector dictionaryIds;
+
+ /**
+ * Update the dictionary.
+ */
+ public void setDictionary(Dictionary dictionary) {
+ this.dictionary = dictionary;
+ }
+
+ /**
+ * Reserve an integer column for the dictionary ids.
+ */
+ public ColumnVector reserveDictionaryIds(int capacity) {
+ if (dictionaryIds == null) {
+ dictionaryIds = allocate(capacity, DataTypes.IntegerType,
+ this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP);
+ } else {
+ dictionaryIds.reset();
+ dictionaryIds.reserve(capacity);
+ }
+ return dictionaryIds;
+ }
+
/**
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
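The net effect of the ColumnVector changes above: string, binary, and decimal reads now go through typed accessors that transparently decode dictionary-encoded values whenever setDictionary has attached a Dictionary. A minimal Scala sketch of the non-dictionary path, assuming ColumnVector.allocate and putByteArray are reachable from the caller (values are illustrative):

```scala
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.execution.vectorized.ColumnVector
import org.apache.spark.sql.types.DataTypes

// Allocate a small on-heap string column and read it back through the new accessors.
val col = ColumnVector.allocate(4, DataTypes.StringType, MemoryMode.ON_HEAP)
col.putByteArray(0, "abc".getBytes("UTF-8"))

// With no dictionary set, getUTF8String/getBinary read the stored bytes directly;
// after setDictionary(...), the same calls would decode via the reserved dictionaryIds column.
assert(col.getUTF8String(0).toString == "abc")
assert(new String(col.getBinary(0), "UTF-8") == "abc")
```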
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 2aeef7f2f90f..681ace338713 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -22,24 +22,20 @@
import java.util.Iterator;
import java.util.List;
+import org.apache.commons.lang.NotImplementedException;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.commons.lang.NotImplementedException;
-
/**
* Utilities to help manipulate data associate with ColumnVectors. These should be used mostly
* for debugging or other non-performance critical paths.
* These utilities are mostly used to convert ColumnVectors into other formats.
*/
public class ColumnVectorUtils {
- public static String toString(ColumnVector.Array a) {
- return new String(a.byteArray, a.byteArrayOffset, a.length);
- }
-
/**
* Returns the array data as the java primitive array.
* For example, an array of IntegerType will return an int[].
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 070d897a7158..8a0d7f8b1237 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,11 +16,11 @@
*/
package org.apache.spark.sql.execution.vectorized;
-import java.math.BigDecimal;
-import java.math.BigInteger;
import java.util.Arrays;
import java.util.Iterator;
+import org.apache.commons.lang.NotImplementedException;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
@@ -31,8 +31,6 @@
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
-import org.apache.commons.lang.NotImplementedException;
-
/**
* This class is the in memory representation of rows as they are streamed through operators. It
* is designed to maximize CPU efficiency and not storage footprint. Since it is expected that
@@ -193,29 +191,17 @@ public final boolean anyNull() {
@Override
public final Decimal getDecimal(int ordinal, int precision, int scale) {
- if (precision <= Decimal.MAX_LONG_DIGITS()) {
- return Decimal.apply(getLong(ordinal), precision, scale);
- } else {
- // TODO: best perf?
- byte[] bytes = getBinary(ordinal);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(javaDecimal, precision, scale);
- }
+ return columns[ordinal].getDecimal(rowId, precision, scale);
}
@Override
public final UTF8String getUTF8String(int ordinal) {
- ColumnVector.Array a = columns[ordinal].getByteArray(rowId);
- return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
+ return columns[ordinal].getUTF8String(rowId);
}
@Override
public final byte[] getBinary(int ordinal) {
- ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
- byte[] bytes = new byte[array.length];
- System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
- return bytes;
+ return columns[ordinal].getBinary(rowId);
}
@Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index e38ed051219b..b06b7f2457b5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -18,25 +18,11 @@
import java.nio.ByteOrder;
-import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.execution.vectorized.ColumnVector.Array;
-import org.apache.spark.sql.types.BooleanType;
-import org.apache.spark.sql.types.ByteType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DateType;
-import org.apache.spark.sql.types.DecimalType;
-import org.apache.spark.sql.types.DoubleType;
-import org.apache.spark.sql.types.FloatType;
-import org.apache.spark.sql.types.IntegerType;
-import org.apache.spark.sql.types.LongType;
-import org.apache.spark.sql.types.ShortType;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.types.UTF8String;
-
-
import org.apache.commons.lang.NotImplementedException;
-import org.apache.commons.lang.NotImplementedException;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.Platform;
/**
* Column data backed using offheap memory.
@@ -171,7 +157,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) {
@Override
public final byte getByte(int rowId) {
- return Platform.getByte(null, data + rowId);
+ if (dictionary == null) {
+ return Platform.getByte(null, data + rowId);
+ } else {
+ return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -199,7 +189,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
@Override
public final short getShort(int rowId) {
- return Platform.getShort(null, data + 2 * rowId);
+ if (dictionary == null) {
+ return Platform.getShort(null, data + 2 * rowId);
+ } else {
+ return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -233,7 +227,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI
@Override
public final int getInt(int rowId) {
- return Platform.getInt(null, data + 4 * rowId);
+ if (dictionary == null) {
+ return Platform.getInt(null, data + 4 * rowId);
+ } else {
+ return dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -267,7 +265,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src
@Override
public final long getLong(int rowId) {
- return Platform.getLong(null, data + 8 * rowId);
+ if (dictionary == null) {
+ return Platform.getLong(null, data + 8 * rowId);
+ } else {
+ return dictionary.decodeToLong(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -301,7 +303,11 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
@Override
public final float getFloat(int rowId) {
- return Platform.getFloat(null, data + rowId * 4);
+ if (dictionary == null) {
+ return Platform.getFloat(null, data + rowId * 4);
+ } else {
+ return dictionary.decodeToFloat(dictionaryIds.getInt(rowId));
+ }
}
@@ -336,7 +342,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
@Override
public final double getDouble(int rowId) {
- return Platform.getDouble(null, data + rowId * 8);
+ if (dictionary == null) {
+ return Platform.getDouble(null, data + rowId * 8);
+ } else {
+ return dictionary.decodeToDouble(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -394,7 +404,7 @@ private final void reserveInternal(int newCapacity) {
} else if (type instanceof ShortType) {
this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2);
} else if (type instanceof IntegerType || type instanceof FloatType ||
- type instanceof DateType) {
+ type instanceof DateType || DecimalType.is32BitDecimalType(type)) {
this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type)) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 3502d31bd1df..305e84a86bdc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -16,13 +16,12 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.util.Arrays;
+
import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.execution.vectorized.ColumnVector.Array;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
-import java.util.Arrays;
-
/**
* A column backed by an in memory JVM array. This stores the NULLs as a byte per value
* and a java array for the values.
@@ -68,7 +67,6 @@ public final void close() {
doubleData = null;
}
-
//
// APIs dealing with nulls
//
@@ -154,7 +152,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) {
@Override
public final byte getByte(int rowId) {
- return byteData[rowId];
+ if (dictionary == null) {
+ return byteData[rowId];
+ } else {
+ return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -180,7 +182,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
@Override
public final short getShort(int rowId) {
- return shortData[rowId];
+ if (dictionary == null) {
+ return shortData[rowId];
+ } else {
+ return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
@@ -217,7 +223,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI
@Override
public final int getInt(int rowId) {
- return intData[rowId];
+ if (dictionary == null) {
+ return intData[rowId];
+ } else {
+ return dictionary.decodeToInt(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -253,7 +263,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src
@Override
public final long getLong(int rowId) {
- return longData[rowId];
+ if (dictionary == null) {
+ return longData[rowId];
+ } else {
+ return dictionary.decodeToLong(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -280,7 +294,13 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
}
@Override
- public final float getFloat(int rowId) { return floatData[rowId]; }
+ public final float getFloat(int rowId) {
+ if (dictionary == null) {
+ return floatData[rowId];
+ } else {
+ return dictionary.decodeToFloat(dictionaryIds.getInt(rowId));
+ }
+ }
//
// APIs dealing with doubles
@@ -309,7 +329,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
@Override
public final double getDouble(int rowId) {
- return doubleData[rowId];
+ if (dictionary == null) {
+ return doubleData[rowId];
+ } else {
+ return dictionary.decodeToDouble(dictionaryIds.getInt(rowId));
+ }
}
//
@@ -377,7 +401,8 @@ private final void reserveInternal(int newCapacity) {
short[] newData = new short[newCapacity];
if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
shortData = newData;
- } else if (type instanceof IntegerType || type instanceof DateType) {
+ } else if (type instanceof IntegerType || type instanceof DateType ||
+ DecimalType.is32BitDecimalType(type)) {
int[] newData = new int[newCapacity];
if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
intData = newData;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index d6bdd3d82556..093504c765ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -453,6 +453,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* format("json").save(path)
* }}}
*
+ * You can set the following JSON-specific option(s) for writing JSON files:
+ * `compression` (default `null`): compression codec to use when saving to file. This can be
+ * one of the known case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`).
+ *
* @since 1.4.0
*/
def json(path: String): Unit = format("json").save(path)
@@ -492,10 +496,29 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* df.write().text("/path/to/output")
* }}}
*
+ * You can set the following option(s) for writing text files:
+ * `compression` (default `null`): compression codec to use when saving to file. This can be
+ * one of the known case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`).
+ *
* @since 1.6.0
*/
def text(path: String): Unit = format("text").save(path)
+ /**
+ * Saves the content of the [[DataFrame]] in CSV format at the specified path.
+ * This is equivalent to:
+ * {{{
+ * format("csv").save(path)
+ * }}}
+ *
+ * You can set the following CSV-specific option(s) for writing CSV files:
+ * `compression` (default `null`): compression codec to use when saving to file. This can be
+ * one of the known case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`).
+ *
+ * @since 2.0.0
+ */
+ def csv(path: String): Unit = format("csv").save(path)
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
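A short usage sketch of the writer options documented above and the new csv() method; df stands for an existing DataFrame and the output paths are hypothetical:

```scala
// "gzip" and "snappy" are two of the shortened codec names accepted by the "compression" option.
df.write.option("compression", "gzip").json("/tmp/events-json")
df.write.option("compression", "snappy").option("header", "true").csv("/tmp/events-csv")
```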
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 75cb6d1137c3..2ea889ea72c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.execution
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -37,7 +39,7 @@ case class Sort(
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
- extends UnaryNode {
+ extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
@@ -50,34 +52,36 @@ case class Sort(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
- protected override def doExecute(): RDD[InternalRow] = {
- val schema = child.schema
- val childOutput = child.output
+ def createSorter(): UnsafeExternalRowSorter = {
+ val ordering = newOrdering(sortOrder, output)
+
+ // The comparator for comparing prefix
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+ val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+ // The generator for prefix
+ val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
+ }
+ }
+ val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+ val sorter = new UnsafeExternalRowSorter(
+ schema, ordering, prefixComparator, prefixComputer, pageSize)
+ if (testSpillFrequency > 0) {
+ sorter.setTestSpillFrequency(testSpillFrequency)
+ }
+ sorter
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
val dataSize = longMetric("dataSize")
val spillSize = longMetric("spillSize")
child.execute().mapPartitionsInternal { iter =>
- val ordering = newOrdering(sortOrder, childOutput)
-
- // The comparator for comparing prefix
- val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
- val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
- // The generator for prefix
- val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
- val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
- }
- }
-
- val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
- val sorter = new UnsafeExternalRowSorter(
- schema, ordering, prefixComparator, prefixComputer, pageSize)
- if (testSpillFrequency > 0) {
- sorter.setTestSpillFrequency(testSpillFrequency)
- }
+ val sorter = createSorter()
val metrics = TaskContext.get().taskMetrics()
// Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +97,74 @@ case class Sort(
sortedIterator
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ // Name of sorter variable used in codegen.
+ private var sorterVariable: String = _
+
+ override protected def doProduce(ctx: CodegenContext): String = {
+ val needToSort = ctx.freshName("needToSort")
+ ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+ // Initialize the class member variables. This includes the instance of the Sorter and
+ // the iterator to return sorted rows.
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ sorterVariable = ctx.freshName("sorter")
+ ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+ s"$sorterVariable = $thisPlan.createSorter();")
+ val metrics = ctx.freshName("metrics")
+ ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+ s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+ val sortedIterator = ctx.freshName("sortedIter")
+ ctx.addMutableState("scala.collection.Iterator", sortedIterator, "")
+
+ val addToSorter = ctx.freshName("addToSorter")
+ ctx.addNewFunction(addToSorter,
+ s"""
+ | private void $addToSorter() throws java.io.IOException {
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ | }
+ """.stripMargin.trim)
+
+ val outputRow = ctx.freshName("outputRow")
+ val dataSize = metricTerm(ctx, "dataSize")
+ val spillSize = metricTerm(ctx, "spillSize")
+ val spillSizeBefore = ctx.freshName("spillSizeBefore")
+ s"""
+ | if ($needToSort) {
+ | $addToSorter();
+ | Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+ | $sortedIterator = $sorterVariable.sort();
+ | $dataSize.add($sorterVariable.getPeakMemoryUsage());
+ | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+ | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+ | $needToSort = false;
+ | }
+ |
+ | while ($sortedIterator.hasNext()) {
+ | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+ | ${consume(ctx, null, outputRow)}
+ | if (shouldStop()) return;
+ | }
+ """.stripMargin.trim
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+ s"""
+ | // Convert the input attributes to an UnsafeRow and add it to the sorter
+ | ${code.code}
+ | $sorterVariable.insertRow(${code.value});
+ """.stripMargin.trim
+ }
}
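With Sort implementing CodegenSupport, a sort whose child also supports codegen is compiled into the same whole-stage function: doProduce drives the child's rows into createSorter()'s UnsafeExternalRowSorter and then loops over the sorted iterator. A quick sketch of how this surfaces, assuming an initialized sqlContext:

```scala
// The physical plan should show a WholeStageCodegen node wrapping Sort and Range.
val df = sqlContext.range(3, 0, -1).sort("id")
df.explain()   // expect: WholeStageCodegen ... Sort ... Range
df.collect()   // Array(Row(1), Row(2), Row(3))
```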
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index dd8c96d5fa1d..0255103b63d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -71,9 +71,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
- // no predicate can be evaluated by matching hash keys
- case logical.Join(left, right, LeftSemi, condition) =>
- joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index afaddcf35775..cb68ca6ada36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
${code.trim}
}
}
- """
+ """.trim
// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
@@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// There is an UnsafeRow already
s"""
|append($row.copy());
- """.stripMargin
+ """.stripMargin.trim
} else {
assert(input != null)
if (input.nonEmpty) {
@@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
s"""
|${code.code.trim}
|append(${code.value}.copy());
- """.stripMargin
+ """.stripMargin.trim
} else {
// There is no columns
s"""
|append(unsafeRow);
- """.stripMargin
+ """.stripMargin.trim
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index ace8cd7ad864..7f1ed28046b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._
-
private[csv] object CSVInferSchema {
/**
@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
- StructField(thisHeader, rootType, nullable = true)
+ val dType = rootType match {
+ case _: NullType => StringType
+ case other => other
+ }
+ StructField(thisHeader, dType, nullable = true)
}
StructType(structFields)
@@ -65,12 +68,8 @@ private[csv] object CSVInferSchema {
}
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
- first.zipAll(second, NullType, NullType).map { case ((a, b)) =>
- val tpe = findTightestCommonType(a, b).getOrElse(StringType)
- tpe match {
- case _: NullType => StringType
- case other => other
- }
+ first.zipAll(second, NullType, NullType).map { case (a, b) =>
+ findTightestCommonType(a, b).getOrElse(NullType)
}
}
@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
+ case (StringType, t2) => Some(StringType)
+ case (t1, StringType) => Some(StringType)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
}
}
-
private[csv] object CSVTypeCast {
/**
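The inference change above keeps NullType during the per-row merge and only falls back to StringType when building the final schema, so a column that is null in early rows can still pick up a concrete type later. A sketch of the merge behaviour using the internal helper directly (the object is private[csv]; shown for illustration only):

```scala
import org.apache.spark.sql.types._

// Columns observed only as nulls stay NullType while merging ...
CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))     // Array(NullType)
// ... so a later non-null observation can still determine the type.
CSVInferSchema.mergeRowTypes(Array(NullType), Array(IntegerType))  // Array(IntegerType)
```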
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index ee6373d03e1f..9e336422d1f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -44,6 +44,12 @@ private[sql] object JDBCRelation {
* exactly once. The parameters minValue and maxValue are advisory in that
* incorrect values may cause the partitioning to be poor, but no data
* will fail to be represented.
+ *
+ * A null-value predicate is added to the first partition's WHERE clause so that the rows
+ * with a null value in the partition column are included.
+ *
+ * @param partitioning partition information used to generate the WHERE clause for each partition
+ * @return an array of partitions, each with its own WHERE clause
*/
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
@@ -66,7 +72,7 @@ private[sql] object JDBCRelation {
if (upperBound == null) {
lowerBound
} else if (lowerBound == null) {
- upperBound
+ s"$upperBound or $column is null"
} else {
s"$lowerBound AND $upperBound"
}
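A read-side sketch of the effect: rows whose partition column is NULL are now matched by the first partition's WHERE clause instead of being dropped. jdbcUrl and the table mirror the test setup added later in this patch and are otherwise hypothetical:

```scala
import java.util.Properties

// Partition TEST.EMP on the nullable quoted column "Dept"; all four rows come back,
// including the ones where "Dept" IS NULL.
val df = sqlContext.read.jdbc(jdbcUrl, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties)
assert(df.collect().length == 4)
```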
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
index 31a95ed46121..e59dbd6b3d43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
@@ -48,10 +48,7 @@ private[sql] class JSONOptions(
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
val allowBackslashEscapingAnyCharacter =
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
- val compressionCodec = {
- val name = parameters.get("compression").orElse(parameters.get("codec"))
- name.map(CompressionCodecs.getCodecClassName)
- }
+ val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
index 42d89f4bf81d..8a128b4b6176 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
@@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter(
}
protected def decimalFromBinary(value: Binary): Decimal = {
- if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) {
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
val unscaled = CatalystRowConverter.binaryToUnscaledLong(value)
Decimal(unscaled, precision, scale)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
index ab4250d0adba..6f6340f541ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
@@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.Type.Repetition._
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64}
+import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter(
case INT_16 => ShortType
case INT_32 | null => IntegerType
case DATE => DateType
- case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32)
+ case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS)
case UINT_8 => typeNotSupported()
case UINT_16 => typeNotSupported()
case UINT_32 => typeNotSupported()
@@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter(
case INT64 =>
originalType match {
case INT_64 | null => LongType
- case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64)
+ case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS)
case UINT_64 => typeNotSupported()
case TIMESTAMP_MILLIS => typeNotImplemented()
case _ => illegalType()
@@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter(
// Uses INT32 for 1 <= precision <= 9
case DecimalType.Fixed(precision, scale)
- if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat =>
+ if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat =>
Types
.primitive(INT32, repetition)
.as(DECIMAL)
@@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter(
// Uses INT64 for 1 <= precision <= 18
case DecimalType.Fixed(precision, scale)
- if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat =>
+ if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat =>
Types
.primitive(INT64, repetition)
.as(DECIMAL)
@@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter {
// Returns the minimum number of bytes needed to store a decimal with a given `precision`.
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision)
- val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */
-
- val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */
-
// Max precision of a decimal value stored in `numBytes` bytes
def maxPrecisionForBytes(numBytes: Int): Int = {
Math.round( // convert double to long
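The constants removed above are replaced by the equivalent values on Decimal; a small Scala sketch of the precision-to-physical-type rule they encode (standard, non-legacy mode):

```scala
import org.apache.spark.sql.types.Decimal

// Decimal.MAX_INT_DIGITS is 9 and Decimal.MAX_LONG_DIGITS is 18, so a DecimalType(p, s) is written as:
//   p <= 9   -> Parquet INT32
//   p <= 18  -> Parquet INT64
//   p <= 38  -> FIXED_LEN_BYTE_ARRAY of minBytesForPrecision(p)
def fitsInt32(p: Int): Boolean = p <= Decimal.MAX_INT_DIGITS
def fitsInt64(p: Int): Boolean = p <= Decimal.MAX_LONG_DIGITS
```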
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala
index 3508220c9541..0252c79d8e14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala
@@ -33,7 +33,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64}
+import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi
writeLegacyParquetFormat match {
// Standard mode, 1 <= precision <= 9, writes as INT32
- case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer
+ case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer
// Standard mode, 10 <= precision <= 18, writes as INT64
- case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer
+ case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer
// Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY
- case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong
+ case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong
// Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY
case _ => binaryWriterUsingUnscaledBytes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 60155b32349a..8f3f6335e428 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -115,10 +115,7 @@ private[sql] class TextRelation(
/** Write path. */
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
val conf = job.getConfiguration
- val compressionCodec = {
- val name = parameters.get("compression").orElse(parameters.get("codec"))
- name.map(CompressionCodecs.getCodecClassName)
- }
+ val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
deleted file mode 100644
index df6dac88187c..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
- * for hash join.
- */
-case class LeftSemiJoinBNL(
- streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode {
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
- override def output: Seq[Attribute] = left.output
-
- /** The Streamed Relation */
- override def left: SparkPlan = streamed
-
- /** The Broadcast relation */
- override def right: SparkPlan = broadcast
-
- override def requiredChildDistribution: Seq[Distribution] = {
- UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
- }
-
- @transient private lazy val boundCondition =
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
-
- streamed.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow
- val relation = broadcastedRelation.value
-
- streamedIter.filter(streamedRow => {
- var i = 0
- var matched = false
-
- while (i < relation.length && !matched) {
- if (boundCondition(joinedRow(streamedRow, relation(i)))) {
- matched = true
- }
- i += 1
- }
- if (matched) {
- numOutputRows += 1
- }
- matched
- })
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index cd543d419528..45175d36d5c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
/**
* Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
*/
-trait BaseLimit extends UnaryNode {
+trait BaseLimit extends UnaryNode with CodegenSupport {
val limit: Int
override def output: Seq[Attribute] = child.output
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode {
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
iter.take(limit)
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val stopEarly = ctx.freshName("stopEarly")
+ ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
+
+ ctx.addNewFunction("shouldStop", s"""
+ @Override
+ protected boolean shouldStop() {
+ return !currentRows.isEmpty() || $stopEarly;
+ }
+ """)
+ val countTerm = ctx.freshName("count")
+ ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+ s"""
+ | if ($countTerm < $limit) {
+ | $countTerm += 1;
+ | ${consume(ctx, input)}
+ | } else {
+ | $stopEarly = true;
+ | }
+ """.stripMargin
+ }
}
/**
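With BaseLimit participating in codegen, the generated loop counts rows in doConsume and flips a stopEarly flag once the limit is reached, so the overridden shouldStop() stops the producer instead of scanning all input. A sketch mirroring the benchmark added below (session setup assumed):

```scala
// Only the first million rows are fed into the aggregate; the generated counter
// stops the range producer early rather than materializing all 500M rows.
sqlContext.range(500L << 20).limit(1000000).groupBy().sum().collect()
```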
diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/simple_sparse.csv
new file mode 100644
index 000000000000..02d29cabf95f
--- /dev/null
+++ b/sql/core/src/test/resources/simple_sparse.csv
@@ -0,0 +1,5 @@
+A,B,C,D
+1,,,
+,1,,
+,,1,
+,,,1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 14fc37b64aa3..33df6375e3aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -621,12 +621,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
ds.filter(_ => true),
Some(1), Some(2), Some(3))
}
+
+ test("SPARK-13540 Dataset of nested class defined in Scala object") {
+ checkAnswer(
+ Seq(OuterObject.InnerClass("foo")).toDS(),
+ OuterObject.InnerClass("foo"))
+ }
}
class OuterClass extends Serializable {
case class InnerClass(a: String)
}
+object OuterObject {
+ case class InnerClass(a: String)
+}
+
case class ClassData(a: String, b: Int)
case class ClassData2(c: String, d: Int)
case class ClassNullableData(a: String, b: Integer)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 3dab848e7b03..5b98c11ef2a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -47,7 +47,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val operators = physical.collect {
case j: LeftSemiJoinHash => j
case j: BroadcastHashJoin => j
- case j: LeftSemiJoinBNL => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: BroadcastLeftSemiJoinHash => j
@@ -67,7 +66,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
- ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
@@ -465,7 +464,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2",
- classOf[LeftSemiJoinBNL]),
+ classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 6d6cc0186a96..2d3e34d0e129 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
+ ignore("range/limit/sum") {
+ val N = 500 << 20
+ runBenchmark("range/limit/sum", N) {
+ sqlContext.range(N).limit(1000000).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X
+ range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X
+ */
+ }
+
ignore("stat functions") {
val N = 100 << 20
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 9350205d791d..de371d85d9fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
+
+ test("Sort should be included in WholeStageCodegen") {
+ val df = sqlContext.range(3, 0, -1).sort(col("id"))
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+ assert(df.collect() === Array(Row(1), Row(2), Row(3)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index a1796f132600..412f1b89beee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
}
+
+ test("Merging Nulltypes should yeild Nulltype.") {
+ val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
+ assert(mergedNullTypes.deep == Array(NullType).deep)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 7671bc106610..3ecbb14f2ea6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
+ private val simpleSparseFile = "simple_sparse.csv"
private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
@@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema.fieldNames.size === 1)
}
-
test("DDL test with empty file") {
sqlContext.sql(s"""
|CREATE TEMPORARY TABLE carsTable
@@ -268,9 +268,8 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.load(testFile(carsFile))
cars.coalesce(1).write
- .format("csv")
.option("header", "true")
- .save(csvDir)
+ .csv(csvDir)
val carsCopy = sqlContext.read
.format("csv")
@@ -396,4 +395,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(carsCopy, withHeader = true)
}
}
+
+ test("Schema inference correctly identifies the datatype when data is sparse.") {
+ val df = sqlContext.read
+ .format("csv")
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .load(testFile(simpleSparseFile))
+
+ assert(
+ df.schema.fields.map(field => field.dataType).deep ==
+ Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
index cef6b79a094d..281a2cffa894 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala
@@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex
assert(batch.column(0).getByte(i) == 1)
assert(batch.column(1).getInt(i) == 2)
assert(batch.column(2).getLong(i) == 3)
- assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc")
+ assert(batch.column(3).getUTF8String(i).toString == "abc")
i += 1
}
reader.close()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 355f916a9755..bc341db5571b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -95,15 +95,6 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
- test(s"$testName using LeftSemiJoinBNL") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- LeftSemiJoinBNL(left, right, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
-
test(s"$testName using BroadcastNestedLoopJoin build left") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index c49f2439fce4..5b4f6f1d2461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
+ test("Sort metrics") {
+ // Assume the execution plan is
+ // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
+ val df = sqlContext.range(10).sort('id)
+ testSparkPlanMetrics(df, 2, Map.empty)
+ }
+
test("SortMergeJoin metrics") {
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
// test should use the deterministic number of partitions.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
index 8efdf8adb042..97638a66ab47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
@@ -370,7 +370,7 @@ object ColumnarBatchBenchmark {
}
i = 0
while (i < count) {
- sum += column.getByteArray(i).length
+ sum += column.getUTF8String(i).numBytes()
i += 1
}
column.reset()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 445f311107e3..b3c3e66fbcbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
reference.zipWithIndex.foreach { v =>
assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode)
- assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)),
+ assert(v._1 == column.getUTF8String(v._2).toString,
"MemoryMode" + memMode)
}
@@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(batch.column(1).getDouble(0) == 1.1)
assert(batch.column(1).getIsNull(0) == false)
assert(batch.column(2).getIsNull(0) == true)
- assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello")
+ assert(batch.column(3).getUTF8String(0).toString == "Hello")
// Verify the iterator works correctly.
val it = batch.rowIterator()
@@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(row.getDouble(1) == 1.1)
assert(row.isNullAt(1) == false)
assert(row.isNullAt(2) == true)
- assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello")
+ assert(batch.column(3).getUTF8String(0).toString == "Hello")
assert(it.hasNext == false)
assert(it.hasNext == false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index f8a9a95c873a..30a5e2ea4acd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -171,6 +171,27 @@ class JDBCSuite extends SparkFunSuite
|OPTIONS (url '$url', dbtable 'TEST.NULLTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
+ conn.prepareStatement(
+ "create table test.emp(name TEXT(32) NOT NULL," +
+ " theid INTEGER, \"Dept\" INTEGER)").executeUpdate()
+ conn.prepareStatement(
+ "insert into test.emp values ('fred', 1, 10)").executeUpdate()
+ conn.prepareStatement(
+ "insert into test.emp values ('mary', 2, null)").executeUpdate()
+ conn.prepareStatement(
+ "insert into test.emp values ('joe ''foo'' \"bar\"', 3, 30)").executeUpdate()
+ conn.prepareStatement(
+ "insert into test.emp values ('kathy', null, null)").executeUpdate()
+ conn.commit()
+
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE nullparts
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.EMP', user 'testUser', password 'testPass',
+ |partitionColumn '"Dept"', lowerBound '1', upperBound '4', numPartitions '4')
+ """.stripMargin.replaceAll("\n", " "))
+
// Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
}
@@ -338,6 +359,23 @@ class JDBCSuite extends SparkFunSuite
.collect().length === 3)
}
+ test("Partioning on column that might have null values.") {
+ assert(
+ sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties)
+ .collect().length === 4)
+ assert(
+ sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties)
+ .collect().length === 4)
+ // partitioning on a nullable quoted column
+ assert(
+ sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties)
+ .collect().length === 4)
+ }
+
+ test("SELECT * on partitioned table with a nullable partioncolumn") {
+ assert(sql("SELECT * FROM nullparts").collect().size == 4)
+ }
+
test("H2 integral types") {
val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect()
assert(rows.length === 1)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 4c9432dbd6ab..aef78fdfd4c5 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -18,7 +18,9 @@
package org.apache.spark.deploy.yarn
import java.io.File
+import java.lang.reflect.UndeclaredThrowableException
import java.nio.charset.StandardCharsets.UTF_8
+import java.security.PrivilegedExceptionAction
import java.util.regex.Matcher
import java.util.regex.Pattern
@@ -194,7 +196,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
*/
def obtainTokenForHiveMetastore(conf: Configuration): Option[Token[DelegationTokenIdentifier]] = {
try {
- obtainTokenForHiveMetastoreInner(conf, UserGroupInformation.getCurrentUser().getUserName)
+ obtainTokenForHiveMetastoreInner(conf)
} catch {
case e: ClassNotFoundException =>
logInfo(s"Hive class not found $e")
@@ -209,8 +211,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
* @param username the username of the principal requesting the delegating token.
* @return a delegation token
*/
- private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration,
- username: String): Option[Token[DelegationTokenIdentifier]] = {
+ private[yarn] def obtainTokenForHiveMetastoreInner(conf: Configuration):
+ Option[Token[DelegationTokenIdentifier]] = {
val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader)
// the hive configuration class is a subclass of Hadoop Configuration, so can be cast down
@@ -225,11 +227,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
// Check for local metastore
if (metastoreUri.nonEmpty) {
- require(username.nonEmpty, "Username undefined")
val principalKey = "hive.metastore.kerberos.principal"
val principal = hiveConf.getTrimmed(principalKey, "")
require(principal.nonEmpty, "Hive principal $principalKey undefined")
- logDebug(s"Getting Hive delegation token for $username against $principal at $metastoreUri")
+ val currentUser = UserGroupInformation.getCurrentUser()
+ logDebug(s"Getting Hive delegation token for ${currentUser.getUserName()} against " +
+ s"$principal at $metastoreUri")
val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive")
val closeCurrent = hiveClass.getMethod("closeCurrent")
try {
@@ -238,12 +241,14 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
classOf[String], classOf[String])
val getHive = hiveClass.getMethod("get", hiveConfClass)
- // invoke
- val hive = getHive.invoke(null, hiveConf)
- val tokenStr = getDelegationToken.invoke(hive, username, principal).asInstanceOf[String]
- val hive2Token = new Token[DelegationTokenIdentifier]()
- hive2Token.decodeFromUrlString(tokenStr)
- Some(hive2Token)
+ doAsRealUser {
+ val hive = getHive.invoke(null, hiveConf)
+ val tokenStr = getDelegationToken.invoke(hive, currentUser.getUserName(), principal)
+ .asInstanceOf[String]
+ val hive2Token = new Token[DelegationTokenIdentifier]()
+ hive2Token.decodeFromUrlString(tokenStr)
+ Some(hive2Token)
+ }
} finally {
Utils.tryLogNonFatalError {
closeCurrent.invoke(null)
@@ -303,6 +308,25 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
}
+ /**
+ * Run some code as the real logged-in user (which may differ from the current user, for
+ * example, when proxy users are in use).
+ */
+ private def doAsRealUser[T](fn: => T): T = {
+ val currentUser = UserGroupInformation.getCurrentUser()
+ val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser)
+
+ // For some reason the Scala-generated anonymous class ends up causing an
+ // UndeclaredThrowableException, even if you annotate the method with @throws.
+ try {
+ realUser.doAs(new PrivilegedExceptionAction[T]() {
+ override def run(): T = fn
+ })
+ } catch {
+ case e: UndeclaredThrowableException => throw Option(e.getCause()).getOrElse(e)
+ }
+ }
+
}
object YarnSparkHadoopUtil {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index d3acaf229cc8..9202bd892f01 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -255,7 +255,7 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
hadoopConf.set("hive.metastore.uris", "http://localhost:0")
val util = new YarnSparkHadoopUtil
assertNestedHiveException(intercept[InvocationTargetException] {
- util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice")
+ util.obtainTokenForHiveMetastoreInner(hadoopConf)
})
assertNestedHiveException(intercept[InvocationTargetException] {
util.obtainTokenForHiveMetastore(hadoopConf)