From 07923c42fcb4d210333b3882490e23f33dc4822f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Dec 2014 13:48:42 -0800 Subject: [PATCH 001/215] support kafka in Python --- python/pyspark/streaming/kafka.py | 71 +++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 python/pyspark/streaming/kafka.py diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py new file mode 100644 index 000000000000..f383362acfac --- /dev/null +++ b/python/pyspark/streaming/kafka.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from py4j.java_collections import MapConverter +from py4j.java_gateway import java_import, Py4JError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['KafkaUtils'] + + +class KafkaUtils(object): + + @staticmethod + def createStream(ssc, zkQuorum, groupId, topics, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + keyDecoder=None, valueDecoder=None): + """ + Create an input stream that pulls messages from a Kafka Broker. + + :param ssc: StreamingContext object + :param zkQuorum: Zookeeper quorum (hostname:port,hostname:port,..). + :param groupId: The group id for this consumer. + :param topics: Dict of (topic_name -> numPartitions) to consume. + Each partition is consumed in its own thread. + :param storageLevel: RDD storage level. 
+ :param keyDecoder: A function used to decode key + :param valueDecoder: A function used to decode value + :return: A DStream object + """ + java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") + + if not isinstance(topics, dict): + raise TypeError("topics should be dict") + jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + try: + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, jlevel) + except Py4JError, e: + if 'call a package' in e.message: + print "No kafka package, please build it and add it into classpath:" + print " $ sbt/sbt streaming-kafka/package" + print " $ bin/submit --driver-class-path external/kafka/target/scala-2.10/" \ + "spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" + raise Exception("No kafka package") + raise e + ser = PairDeserializer(UTF8Deserializer(), UTF8Deserializer()) + stream = DStream(jstream, ssc, ser) + + if keyDecoder is not None: + stream = stream.map(lambda (k, v): (keyDecoder(k), v)) + if valueDecoder is not None: + stream = stream.mapValues(valueDecoder) + return stream \ No newline at end of file From 75d485e65b75a7a5da91a37ff42a9bb7cd82dcf6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Dec 2014 14:18:59 -0800 Subject: [PATCH 002/215] add mqtt --- python/pyspark/streaming/kafka.py | 2 +- python/pyspark/streaming/mqtt.py | 53 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 python/pyspark/streaming/mqtt.py diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index f383362acfac..ed52f853658f 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -68,4 +68,4 @@ def createStream(ssc, zkQuorum, groupId, topics, stream = stream.map(lambda (k, v): (keyDecoder(k), v)) if valueDecoder is not None: stream = stream.mapValues(valueDecoder) - return stream \ No newline at end of file + return stream diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 000000000000..78e4cb7f1464 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from py4j.java_gateway import java_import, Py4JError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that receives messages pushed by a MQTT publisher. 
+ + :param ssc: StreamingContext object + :param brokerUrl: Url of remote MQTT publisher + :param topic: Topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + java_import(ssc._jvm, "org.apache.spark.streaming.mqtt.MQTTUtils") + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + try: + jstream = ssc._jvm.MQTTUtils.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JError, e: + if 'call a package' in e.message: + print "No MQTT package, please build it and add it into classpath:" + print " $ sbt/sbt streaming-mqtt/package" + print " $ bin/submit --driver-class-path external/mqtt/target/scala-2.10/" \ + "spark-streaming-mqtt_2.10-1.3.0-SNAPSHOT.jar" + raise Exception("No mqtt package") + raise e + return DStream(jstream, ssc, UTF8Deserializer()) From 048dbe6c9ec4bff4eeee52c70e4e18d48d3075e0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Dec 2014 14:27:43 -0800 Subject: [PATCH 003/215] fix python style --- python/pyspark/streaming/kafka.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index ed52f853658f..24d62db1979b 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -52,7 +52,8 @@ def createStream(ssc, zkQuorum, groupId, topics, jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: - jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, jlevel) + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, + jlevel) except Py4JError, e: if 'call a package' in e.message: print "No kafka package, please build it and add it into classpath:" From 5697a012def1b8508d21d96ccccf2afb7d6705cf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Dec 2014 15:44:33 -0800 Subject: [PATCH 004/215] bypass decoder in scala --- python/pyspark/streaming/kafka.py | 39 +++++++++++++++++++------------ 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 24d62db1979b..6d0cb5d4c1f6 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,23 +15,26 @@ # limitations under the License. # - from py4j.java_collections import MapConverter from py4j.java_gateway import java_import, Py4JError from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, UTF8Deserializer +from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream __all__ = ['KafkaUtils'] +def utf8_decoder(s): + return s.decode('utf-8') + + class KafkaUtils(object): @staticmethod def createStream(ssc, zkQuorum, groupId, topics, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, - keyDecoder=None, valueDecoder=None): + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ Create an input stream that pulls messages from a Kafka Broker. 
@@ -47,26 +50,32 @@ def createStream(ssc, zkQuorum, groupId, topics, """ java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") + param = { + "zookeeper.connect": zkQuorum, + "group.id": groupId, + "zookeeper.connection.timeout.ms": "10000", + } if not isinstance(topics, dict): raise TypeError("topics should be dict") jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) + jparam = MapConverter().convert(param, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + def getClassByName(name): + return ssc._jvm.org.apache.spark.util.Utils.classForName(name) + try: - jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, - jlevel) + array = getClassByName("[B") + decoder = getClassByName("kafka.serializer.DefaultDecoder") + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder, + jparam, jtopics, jlevel) except Py4JError, e: if 'call a package' in e.message: print "No kafka package, please build it and add it into classpath:" print " $ sbt/sbt streaming-kafka/package" - print " $ bin/submit --driver-class-path external/kafka/target/scala-2.10/" \ - "spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" - raise Exception("No kafka package") + print " $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \ + "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" raise e - ser = PairDeserializer(UTF8Deserializer(), UTF8Deserializer()) + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) - - if keyDecoder is not None: - stream = stream.map(lambda (k, v): (keyDecoder(k), v)) - if valueDecoder is not None: - stream = stream.mapValues(valueDecoder) - return stream + return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v))) From 98c8d179d3ff264d03eabc3ddd72936d95e6e305 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Dec 2014 15:58:29 -0800 Subject: [PATCH 005/215] fix python style --- python/pyspark/streaming/kafka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 6d0cb5d4c1f6..c27699fee6b8 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -74,7 +74,7 @@ def getClassByName(name): print "No kafka package, please build it and add it into classpath:" print " $ sbt/sbt streaming-kafka/package" print " $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \ - "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" + "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) From 815de54002f9c1cfedc398e95896fa207b4a5305 Mon Sep 17 00:00:00 2001 From: YanTangZhai Date: Mon, 29 Dec 2014 11:30:54 -0800 Subject: [PATCH 006/215] [SPARK-4946] [CORE] Using AkkaUtils.askWithReply in MapOutputTracker.askTracker to reduce the chance of the communicating problem Using AkkaUtils.askWithReply in MapOutputTracker.askTracker to reduce the chance of the communicating problem Author: YanTangZhai Author: yantangzhai Closes #3785 from YanTangZhai/SPARK-4946 and squashes the following commits: 9ca6541 [yantangzhai] [SPARK-4946] [CORE] Using AkkaUtils.askWithReply in MapOutputTracker.askTracker to reduce the chance of the communicating problem e4c2c0a [YanTangZhai] Merge pull request #15 
from apache/master 718afeb [YanTangZhai] Merge pull request #12 from apache/master 6e643f8 [YanTangZhai] Merge pull request #11 from apache/master e249846 [YanTangZhai] Merge pull request #10 from apache/master d26d982 [YanTangZhai] Merge pull request #9 from apache/master 76d4027 [YanTangZhai] Merge pull request #8 from apache/master 03b62b0 [YanTangZhai] Merge pull request #7 from apache/master 8a00106 [YanTangZhai] Merge pull request #6 from apache/master cbcba66 [YanTangZhai] Merge pull request #3 from apache/master cdef539 [YanTangZhai] Merge pull request #1 from apache/master --- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a074ab8ece1b..6e4edc7c80d7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -76,6 +76,8 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) + private val retryAttempts = AkkaUtils.numRetries(conf) + private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) /** Set to the MapOutputTrackerActor living on the driver. */ var trackerActor: ActorRef = _ @@ -108,8 +110,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker(message: Any): Any = { try { - val future = trackerActor.ask(message)(timeout) - Await.result(future, timeout) + AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) From 8d72341ab75a7fb138b056cfb4e21db42aca55fb Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 29 Dec 2014 12:05:08 -0800 Subject: [PATCH 007/215] [Minor] Fix a typo of type parameter in JavaUtils.scala In JavaUtils.scala, thare is a typo of type parameter. In addition, the type information is removed at the time of compile by erasure. This issue is really minor so I don't file in JIRA. 
Author: Kousuke Saruta Closes #3789 from sarutak/fix-typo-in-javautils and squashes the following commits: e20193d [Kousuke Saruta] Fixed a typo of type parameter 82bc5d9 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into fix-typo-in-javautils 99f6f63 [Kousuke Saruta] Fixed a typo of type parameter in JavaUtils.scala --- core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 86e94931300f..71b26737b8c0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -80,7 +80,7 @@ private[spark] object JavaUtils { prev match { case Some(k) => underlying match { - case mm: mutable.Map[a, _] => + case mm: mutable.Map[A, _] => mm remove k prev = None case _ => From 02b55de3dce9a1fef806be13e5cefa0f39ea2fcc Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Dec 2014 13:24:26 -0800 Subject: [PATCH 008/215] [SPARK-4409][MLlib] Additional Linear Algebra Utils Addition of a very limited number of local matrix manipulation and generation methods that would be helpful in the further development for algorithms on top of BlockMatrix (SPARK-3974), such as Randomized SVD, and Multi Model Training (SPARK-1486). The proposed methods for addition are: For `Matrix` - map: maps the values in the matrix with a given function. Produces a new matrix. - update: the values in the matrix are updated with a given function. Occurs in place. Factory methods for `DenseMatrix`: - *zeros: Generate a matrix consisting of zeros - *ones: Generate a matrix consisting of ones - *eye: Generate an identity matrix - *rand: Generate a matrix consisting of i.i.d. uniform random numbers - *randn: Generate a matrix consisting of i.i.d. gaussian random numbers - *diag: Generate a diagonal matrix from a supplied vector *These methods already exist in the factory methods for `Matrices`, however for cases where we require a `DenseMatrix`, you constantly have to add `.asInstanceOf[DenseMatrix]` everywhere, which makes the code "dirtier". I propose moving these functions to factory methods for `DenseMatrix` where the putput will be a `DenseMatrix` and the factory methods for `Matrices` will call these functions directly and output a generic `Matrix`. Factory methods for `SparseMatrix`: - speye: Identity matrix in sparse format. Saves a ton of memory when dimensions are large, especially in Multi Model Training, where each row requires being multiplied by a scalar. - sprand: Generate a sparse matrix with a given density consisting of i.i.d. uniform random numbers. - sprandn: Generate a sparse matrix with a given density consisting of i.i.d. gaussian random numbers. - diag: Generate a diagonal matrix from a supplied vector, but is memory efficient, because it just stores the diagonal. Again, very helpful in Multi Model Training. Factory methods for `Matrices`: - Include all the factory methods given above, but return a generic `Matrix` rather than `SparseMatrix` or `DenseMatrix`. - horzCat: Horizontally concatenate matrices to form one larger matrix. Very useful in both Multi Model Training, and for the repartitioning of BlockMatrix. - vertCat: Vertically concatenate matrices to form one larger matrix. Very useful for the repartitioning of BlockMatrix. 
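As a rough illustration (not part of the patch itself), a minimal Scala sketch of how the proposed factory methods could be used, assuming an MLlib build that includes these changes; names and dimensions here are chosen only for the example:

    import java.util.Random
    import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vectors}

    val rng = new Random(42)
    // Dense and sparse identity matrices
    val eye3: Matrix = Matrices.eye(3)      // dense 3 x 3 identity
    val spEye3: Matrix = Matrices.speye(3)  // sparse 3 x 3 identity
    // Random matrices: dense U(0, 1) values, and a sparse matrix with ~25% non-zeros
    val d: Matrix = Matrices.rand(3, 2, rng)
    val s: Matrix = Matrices.sprand(3, 2, 0.25, rng)
    // Diagonal matrix from a vector
    val diag: Matrix = Matrices.diag(Vectors.dense(1.0, 2.0, 3.0))
    // Concatenation: horzcat needs matching row counts, vertcat matching column counts
    val wide: Matrix = Matrices.horzcat(Array(d, eye3))             // 3 x 5
    val tall: Matrix = Matrices.vertcat(Array(d, Matrices.eye(2)))  // 5 x 2

Per the scaladoc added in this patch, mixing dense and sparse inputs in horzcat or vertcat yields a sparse result, while all-dense inputs yield a dense result.
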
The names for these methods were selected from MATLAB Author: Burak Yavuz Author: Xiangrui Meng Closes #3319 from brkyvz/SPARK-4409 and squashes the following commits: b0354f6 [Burak Yavuz] [SPARK-4409] Incorporated mengxr's code 04c4829 [Burak Yavuz] Merge pull request #1 from mengxr/SPARK-4409 80cfa29 [Xiangrui Meng] minor changes ecc937a [Xiangrui Meng] update sprand 4e95e24 [Xiangrui Meng] simplify fromCOO implementation 10a63a6 [Burak Yavuz] [SPARK-4409] Fourth pass of code review f62d6c7 [Burak Yavuz] [SPARK-4409] Modified genRandMatrix 3971c93 [Burak Yavuz] [SPARK-4409] Third pass of code review 75239f8 [Burak Yavuz] [SPARK-4409] Second pass of code review e4bd0c0 [Burak Yavuz] [SPARK-4409] Modified horzcat and vertcat 65c562e [Burak Yavuz] [SPARK-4409] Hopefully fixed Java Test d8be7bc [Burak Yavuz] [SPARK-4409] Organized imports 065b531 [Burak Yavuz] [SPARK-4409] First pass after code review a8120d2 [Burak Yavuz] [SPARK-4409] Finished updates to API according to SPARK-4614 f798c82 [Burak Yavuz] [SPARK-4409] Updated API according to SPARK-4614 c75f3cd [Burak Yavuz] [SPARK-4409] Added JavaAPI Tests, and fixed a couple of bugs d662f9d [Burak Yavuz] [SPARK-4409] Modified according to remote repo 83dfe37 [Burak Yavuz] [SPARK-4409] Scalastyle error fixed a14c0da [Burak Yavuz] [SPARK-4409] Initial commit to add methods --- .../apache/spark/mllib/linalg/Matrices.scala | 570 ++++++++++++++++-- .../spark/mllib/linalg/JavaMatricesSuite.java | 163 +++++ .../spark/mllib/linalg/MatricesSuite.scala | 172 +++++- .../spark/mllib/util/TestingUtils.scala | 6 +- 4 files changed, 868 insertions(+), 43 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 327366a1a3a8..5a7281ec6dc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,9 +17,11 @@ package org.apache.spark.mllib.linalg -import java.util.{Random, Arrays} +import java.util.{Arrays, Random} -import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} +import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer} + +import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} /** * Trait for a local matrix. @@ -80,6 +82,16 @@ sealed trait Matrix extends Serializable { /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() + + /** Map the values of this matrix using a function. Generates a new matrix. Performs the + * function on only the backing array. For example, an operation such as addition or + * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ + private[mllib] def map(f: Double => Double): Matrix + + /** Update all the values of this matrix using the function f. Performed in-place on the + * backing array. For example, an operation such as addition or subtraction will only be + * performed on the non-zero values in a `SparseMatrix`. 
*/ + private[mllib] def update(f: Double => Double): Matrix } /** @@ -123,6 +135,122 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) } override def copy = new DenseMatrix(numRows, numCols, values.clone()) + + private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) + + private[mllib] def update(f: Double => Double): DenseMatrix = { + val len = values.length + var i = 0 + while (i < len) { + values(i) = f(values(i)) + i += 1 + } + this + } + + /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ + def toSparse(): SparseMatrix = { + val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble + val colPtrs: Array[Int] = new Array[Int](numCols + 1) + val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt + var nnz = 0 + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + val v = values(indStart + i) + if (v != 0.0) { + rowIndices += i + spVals += v + nnz += 1 + } + i += 1 + } + j += 1 + colPtrs(j) = nnz + } + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) + } +} + +/** + * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]]. + */ +object DenseMatrix { + + /** + * Generate a `DenseMatrix` consisting of zeros. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros + */ + def zeros(numRows: Int, numCols: Int): DenseMatrix = + new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + + /** + * Generate a `DenseMatrix` consisting of ones. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones + */ + def ones(numRows: Int, numCols: Int): DenseMatrix = + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + + /** + * Generate an Identity Matrix in `DenseMatrix` format. + * @param n number of rows and columns of the matrix + * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal + */ + def eye(n: Int): DenseMatrix = { + val identity = DenseMatrix.zeros(n, n) + var i = 0 + while (i < n) { + identity.update(i, i, 1.0) + i += 1 + } + identity + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param rng a random number generator + * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param rng a random number generator + * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) + } + + /** + * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. 
+ * @param vector a `Vector` that will form the values on the diagonal of the matrix + * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` + * on the diagonal + */ + def diag(vector: Vector): DenseMatrix = { + val n = vector.size + val matrix = DenseMatrix.zeros(n, n) + val values = vector.toArray + var i = 0 + while (i < n) { + matrix.update(i, i, values(i)) + i += 1 + } + matrix + } } /** @@ -156,6 +284,8 @@ class SparseMatrix( require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + s"numCols: $numCols") + require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") override def toArray: Array[Double] = { val arr = new Array[Double](numRows * numCols) @@ -188,7 +318,7 @@ class SparseMatrix( private[mllib] def update(i: Int, j: Int, v: Double): Unit = { val ind = index(i, j) - if (ind == -1){ + if (ind == -1) { throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { @@ -197,6 +327,192 @@ class SparseMatrix( } override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) + + private[mllib] def map(f: Double => Double) = + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) + + private[mllib] def update(f: Double => Double): SparseMatrix = { + val len = values.length + var i = 0 + while (i < len) { + values(i) = f(values(i)) + i += 1 + } + this + } + + /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ + def toDense(): DenseMatrix = { + new DenseMatrix(numRows, numCols, toArray) + } +} + +/** + * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]]. + */ +object SparseMatrix { + + /** + * Generate a `SparseMatrix` from Coordinate List (COO) format. Input must be an array of + * (i, j, value) tuples. Entries that have duplicate values of i and j are + * added together. Tuples where value is equal to zero will be omitted. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param entries Array of (i, j, value) tuples + * @return The corresponding `SparseMatrix` + */ + def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { + val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) + val numEntries = sortedEntries.size + if (sortedEntries.nonEmpty) { + // Since the entries are sorted by column index, we only need to check the first and the last. + for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) { + require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.") + } + } + val colPtrs = new Array[Int](numCols + 1) + val rowIndices = MArrayBuilder.make[Int] + rowIndices.sizeHint(numEntries) + val values = MArrayBuilder.make[Double] + values.sizeHint(numEntries) + var nnz = 0 + var prevCol = 0 + var prevRow = -1 + var prevVal = 0.0 + // Append a dummy entry to include the last one at the end of the loop. 
+ (sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) => + if (v != 0) { + if (i == prevRow && j == prevCol) { + prevVal += v + } else { + if (prevVal != 0) { + require(prevRow >= 0 && prevRow < numRows, + s"Row index out of range [0, $numRows): $prevRow.") + nnz += 1 + rowIndices += prevRow + values += prevVal + } + prevRow = i + prevVal = v + while (prevCol < j) { + colPtrs(prevCol + 1) = nnz + prevCol += 1 + } + } + } + } + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result()) + } + + /** + * Generate an Identity Matrix in `SparseMatrix` format. + * @param n number of rows and columns of the matrix + * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal + */ + def speye(n: Int): SparseMatrix = { + new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) + } + + /** + * Generates the skeleton of a random `SparseMatrix` with a given random number generator. + * The values of the matrix returned are undefined. + */ + private def genRandMatrix( + numRows: Int, + numCols: Int, + density: Double, + rng: Random): SparseMatrix = { + require(numRows > 0, s"numRows must be greater than 0 but got $numRows") + require(numCols > 0, s"numCols must be greater than 0 but got $numCols") + require(density >= 0.0 && density <= 1.0, + s"density must be a double in the range 0.0 <= d <= 1.0. Currently, density: $density") + val size = numRows.toLong * numCols + val expected = size * density + assert(expected < Int.MaxValue, + "The expected number of nonzeros cannot be greater than Int.MaxValue.") + val nnz = math.ceil(expected).toInt + if (density == 0.0) { + new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1), Array[Int](), Array[Double]()) + } else if (density == 1.0) { + val colPtrs = Array.tabulate(numCols + 1)(j => j * numRows) + val rowIndices = Array.tabulate(size.toInt)(idx => idx % numRows) + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](numRows * numCols)) + } else if (density < 0.34) { + // draw-by-draw, expected number of iterations is less than 1.5 * nnz + val entries = MHashSet[(Int, Int)]() + while (entries.size < nnz) { + entries += ((rng.nextInt(numRows), rng.nextInt(numCols))) + } + SparseMatrix.fromCOO(numRows, numCols, entries.map(v => (v._1, v._2, 1.0))) + } else { + // selection-rejection method + var idx = 0L + var numSelected = 0 + var j = 0 + val colPtrs = new Array[Int](numCols + 1) + val rowIndices = new Array[Int](nnz) + while (j < numCols && numSelected < nnz) { + var i = 0 + while (i < numRows && numSelected < nnz) { + if (rng.nextDouble() < 1.0 * (nnz - numSelected) / (size - idx)) { + rowIndices(numSelected) = i + numSelected += 1 + } + i += 1 + idx += 1 + } + colPtrs(j + 1) = numSelected + j += 1 + } + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](nnz)) + } + } + + /** + * Generate a `SparseMatrix` consisting of i.i.d. uniform random numbers. 
The number of non-zero + * elements equal the ceiling of `numRows` x `numCols` x `density` + * + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { + val mat = genRandMatrix(numRows, numCols, density, rng) + mat.update(i => rng.nextDouble()) + } + + /** + * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { + val mat = genRandMatrix(numRows, numCols, density, rng) + mat.update(i => rng.nextGaussian()) + } + + /** + * Generate a diagonal matrix in `SparseMatrix` format from the supplied values. + * @param vector a `Vector` that will form the values on the diagonal of the matrix + * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero + * `values` on the diagonal + */ + def diag(vector: Vector): SparseMatrix = { + val n = vector.size + vector match { + case sVec: SparseVector => + SparseMatrix.fromCOO(n, n, sVec.indices.zip(sVec.values).map(v => (v._1, v._1, v._2))) + case dVec: DenseVector => + val entries = dVec.values.zipWithIndex + val nnzVals = entries.filter(v => v._1 != 0.0) + SparseMatrix.fromCOO(n, n, nnzVals.map(v => (v._2, v._2, v._1))) + } + } } /** @@ -256,72 +572,250 @@ object Matrices { * Generate a `DenseMatrix` consisting of zeros. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros + * @return `Matrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): Matrix = - new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** * Generate a `DenseMatrix` consisting of ones. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix - * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones + * @return `Matrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): Matrix = - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** - * Generate an Identity Matrix in `DenseMatrix` format. + * Generate a dense Identity Matrix in `Matrix` format. * @param n number of rows and columns of the matrix - * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal + * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ - def eye(n: Int): Matrix = { - val identity = Matrices.zeros(n, n) - var i = 0 - while (i < n){ - identity.update(i, i, 1.0) - i += 1 - } - identity - } + def eye(n: Int): Matrix = DenseMatrix.eye(n) + + /** + * Generate a sparse Identity Matrix in `Matrix` format. 
+ * @param n number of rows and columns of the matrix + * @return `Matrix` with size `n` x `n` and values of ones on the diagonal + */ + def speye(n: Int): Matrix = SparseMatrix.speye(n) /** * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator - * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int, rng: Random): Matrix = { - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) - } + def rand(numRows: Int, numCols: Int, rng: Random): Matrix = + DenseMatrix.rand(numRows, numCols, rng) + + /** + * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = + SparseMatrix.sprand(numRows, numCols, density, rng) /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @param rng a random number generator - * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int, rng: Random): Matrix = { - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) - } + def randn(numRows: Int, numCols: Int, rng: Random): Matrix = + DenseMatrix.randn(numRows, numCols, rng) + + /** + * Generate a `SparseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @param density the desired density for the matrix + * @param rng a random number generator + * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = + SparseMatrix.sprandn(numRows, numCols, density, rng) /** * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. * @param vector a `Vector` tat will form the values on the diagonal of the matrix - * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` + * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ - def diag(vector: Vector): Matrix = { - val n = vector.size - val matrix = Matrices.eye(n) - val values = vector.toArray - var i = 0 - while (i < n) { - matrix.update(i, i, values(i)) - i += 1 + def diag(vector: Vector): Matrix = DenseMatrix.diag(vector) + + /** + * Horizontally concatenate a sequence of matrices. The returned matrix will be in the format + * the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in + * a sparse matrix. If the Array is empty, an empty `DenseMatrix` will be returned. 
+ * @param matrices array of matrices + * @return a single `Matrix` composed of the matrices that were horizontally concatenated + */ + def horzcat(matrices: Array[Matrix]): Matrix = { + if (matrices.isEmpty) { + return new DenseMatrix(0, 0, Array[Double]()) + } else if (matrices.size == 1) { + return matrices(0) + } + val numRows = matrices(0).numRows + var hasSparse = false + var numCols = 0 + matrices.foreach { mat => + require(numRows == mat.numRows, "The number of rows of the matrices in this sequence, " + + "don't match!") + mat match { + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose + case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") + } + numCols += mat.numCols + } + if (!hasSparse) { + new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) + } else { + var startCol = 0 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { + case spMat: SparseMatrix => + var j = 0 + val colPtrs = spMat.colPtrs + val rowIndices = spMat.rowIndices + val values = spMat.values + val data = new Array[(Int, Int, Double)](values.length) + val nCols = spMat.numCols + while (j < nCols) { + var idx = colPtrs(j) + while (idx < colPtrs(j + 1)) { + val i = rowIndices(idx) + val v = values(idx) + data(idx) = (i, j + startCol, v) + idx += 1 + } + j += 1 + } + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + var j = 0 + val nCols = dnMat.numCols + val nRows = dnMat.numRows + val values = dnMat.values + while (j < nCols) { + var i = 0 + val indStart = j * nRows + while (i < nRows) { + val v = values(indStart + i) + if (v != 0.0) { + data.append((i, j + startCol, v)) + } + i += 1 + } + j += 1 + } + startCol += nCols + data + } + SparseMatrix.fromCOO(numRows, numCols, entries) + } + } + + /** + * Vertically concatenate a sequence of matrices. The returned matrix will be in the format + * the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in + * a sparse matrix. If the Array is empty, an empty `DenseMatrix` will be returned. + * @param matrices array of matrices + * @return a single `Matrix` composed of the matrices that were vertically concatenated + */ + def vertcat(matrices: Array[Matrix]): Matrix = { + if (matrices.isEmpty) { + return new DenseMatrix(0, 0, Array[Double]()) + } else if (matrices.size == 1) { + return matrices(0) + } + val numCols = matrices(0).numCols + var hasSparse = false + var numRows = 0 + matrices.foreach { mat => + require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + + "don't match!") + mat match { + case sparse: SparseMatrix => + hasSparse = true + case dense: DenseMatrix => + case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + + s"SparseMatrix or DenseMatrix. 
Instead got: ${mat.getClass}") + } + numRows += mat.numRows + + } + if (!hasSparse) { + val allValues = new Array[Double](numRows * numCols) + var startRow = 0 + matrices.foreach { mat => + var j = 0 + val nRows = mat.numRows + val values = mat.toArray + while (j < numCols) { + var i = 0 + val indStart = j * numRows + startRow + val subMatStart = j * nRows + while (i < nRows) { + allValues(indStart + i) = values(subMatStart + i) + i += 1 + } + j += 1 + } + startRow += nRows + } + new DenseMatrix(numRows, numCols, allValues) + } else { + var startRow = 0 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { + case spMat: SparseMatrix => + var j = 0 + val colPtrs = spMat.colPtrs + val rowIndices = spMat.rowIndices + val values = spMat.values + val data = new Array[(Int, Int, Double)](values.length) + while (j < numCols) { + var idx = colPtrs(j) + while (idx < colPtrs(j + 1)) { + val i = rowIndices(idx) + val v = values(idx) + data(idx) = (i + startRow, j, v) + idx += 1 + } + j += 1 + } + startRow += spMat.numRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + var j = 0 + val nCols = dnMat.numCols + val nRows = dnMat.numRows + val values = dnMat.values + while (j < nCols) { + var i = 0 + val indStart = j * nRows + while (i < nRows) { + val v = values(indStart + i) + if (v != 0.0) { + data.append((i + startRow, j, v)) + } + i += 1 + } + j += 1 + } + startRow += nRows + data + } + SparseMatrix.fromCOO(numRows, numCols, entries) } - matrix } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java new file mode 100644 index 000000000000..704d484d0b58 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.linalg; + +import static org.junit.Assert.*; +import org.junit.Test; + +import java.io.Serializable; +import java.util.Random; + +public class JavaMatricesSuite implements Serializable { + + @Test + public void randMatrixConstruction() { + Random rng = new Random(24); + Matrix r = Matrices.rand(3, 4, rng); + rng.setSeed(24); + DenseMatrix dr = DenseMatrix.rand(3, 4, rng); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + + rng.setSeed(24); + Matrix rn = Matrices.randn(3, 4, rng); + rng.setSeed(24); + DenseMatrix drn = DenseMatrix.randn(3, 4, rng); + assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); + + rng.setSeed(24); + Matrix s = Matrices.sprand(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); + assertArrayEquals(s.toArray(), sr.toArray(), 0.0); + + rng.setSeed(24); + Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); + assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); + } + + @Test + public void identityMatrixConstruction() { + Matrix r = Matrices.eye(2); + DenseMatrix dr = DenseMatrix.eye(2); + SparseMatrix sr = SparseMatrix.speye(2); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); + assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); + } + + @Test + public void diagonalMatrixConstruction() { + Vector v = Vectors.dense(1.0, 0.0, 2.0); + Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); + + Matrix m = Matrices.diag(v); + Matrix sm = Matrices.diag(sv); + DenseMatrix d = DenseMatrix.diag(v); + DenseMatrix sd = DenseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.diag(v); + SparseMatrix ss = SparseMatrix.diag(sv); + + assertArrayEquals(m.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sd.toArray(), 0.0); + assertArrayEquals(sd.toArray(), s.toArray(), 0.0); + assertArrayEquals(s.toArray(), ss.toArray(), 0.0); + assertArrayEquals(s.values(), ss.values(), 0.0); + assert(s.values().length == 2); + assert(ss.values().length == 2); + assert(s.colPtrs().length == 4); + assert(ss.colPtrs().length == 4); + } + + @Test + public void zerosMatrixConstruction() { + Matrix z = Matrices.zeros(2, 2); + Matrix one = Matrices.ones(2, 2); + DenseMatrix dz = DenseMatrix.zeros(2, 2); + DenseMatrix done = DenseMatrix.ones(2, 2); + + assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + } + + @Test + public void sparseDenseConversion() { + int m = 3; + int n = 2; + double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; + double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; + int[] colPtrs = new int[]{0, 2, 4}; + int[] rowIndices = new int[]{0, 1, 1, 2}; + + SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); + DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); + + SparseMatrix spMat2 = deMat1.toSparse(); + DenseMatrix deMat2 = spMat1.toDense(); + + assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); + assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); + } + + @Test + public void concatenateMatrices() { + int m = 3; + int n = 2; + + Random rng = new Random(42); + SparseMatrix spMat1 = 
SparseMatrix.sprand(m, n, 0.5, rng); + rng.setSeed(42); + DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); + Matrix deMat2 = Matrices.eye(3); + Matrix spMat2 = Matrices.speye(3); + Matrix deMat3 = Matrices.eye(2); + Matrix spMat3 = Matrices.speye(2); + + Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); + Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); + Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); + Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); + + assert(deHorz1.numRows() == 3); + assert(deHorz2.numRows() == 3); + assert(deHorz3.numRows() == 3); + assert(spHorz.numRows() == 3); + assert(deHorz1.numCols() == 5); + assert(deHorz2.numCols() == 5); + assert(deHorz3.numCols() == 5); + assert(spHorz.numCols() == 5); + + Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); + Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); + Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); + Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); + + assert(deVert1.numRows() == 5); + assert(deVert2.numRows() == 5); + assert(deVert3.numRows() == 5); + assert(spVert.numRows() == 5); + assert(deVert1.numCols() == 2); + assert(deVert2.numCols() == 2); + assert(deVert3.numCols() == 2); + assert(spVert.numCols() == 2); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 322a0e924291..a35d0fe389fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -43,9 +43,9 @@ class MatricesSuite extends FunSuite { test("sparse matrix construction") { val m = 3 - val n = 2 + val n = 4 val values = Array(1.0, 2.0, 4.0, 5.0) - val colPtrs = Array(0, 2, 4) + val colPtrs = Array(0, 2, 2, 4, 4) val rowIndices = Array(1, 2, 1, 2) val mat = Matrices.sparse(m, n, colPtrs, rowIndices, values).asInstanceOf[SparseMatrix] assert(mat.numRows === m) @@ -53,6 +53,13 @@ class MatricesSuite extends FunSuite { assert(mat.values.eq(values), "should not copy data") assert(mat.colPtrs.eq(colPtrs), "should not copy data") assert(mat.rowIndices.eq(rowIndices), "should not copy data") + + val entries: Array[(Int, Int, Double)] = Array((2, 2, 3.0), (1, 0, 1.0), (2, 0, 2.0), + (1, 2, 2.0), (2, 2, 2.0), (1, 2, 2.0), (0, 0, 0.0)) + + val mat2 = SparseMatrix.fromCOO(m, n, entries) + assert(mat.toBreeze === mat2.toBreeze) + assert(mat2.values.length == 4) } test("sparse matrix construction with wrong number of elements") { @@ -117,6 +124,142 @@ class MatricesSuite extends FunSuite { assert(sparseMat.values(2) === 10.0) } + test("toSparse, toDense") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val deMat1 = new DenseMatrix(m, n, allValues) + + val spMat2 = deMat1.toSparse() + val deMat2 = spMat1.toDense() + + assert(spMat1.toBreeze === spMat2.toBreeze) + assert(deMat1.toBreeze === deMat2.toBreeze) + } + + test("map, update") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val deMat1 = new 
DenseMatrix(m, n, allValues) + val deMat2 = deMat1.map(_ * 2) + val spMat2 = spMat1.map(_ * 2) + deMat1.update(_ * 2) + spMat1.update(_ * 2) + + assert(spMat1.toArray === spMat2.toArray) + assert(deMat1.toArray === deMat2.toArray) + } + + test("horzcat, vertcat, eye, speye") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val deMat1 = new DenseMatrix(m, n, allValues) + val deMat2 = Matrices.eye(3) + val spMat2 = Matrices.speye(3) + val deMat3 = Matrices.eye(2) + val spMat3 = Matrices.speye(2) + + val spHorz = Matrices.horzcat(Array(spMat1, spMat2)) + val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) + val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) + val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) + + val deHorz2 = Matrices.horzcat(Array[Matrix]()) + + assert(deHorz1.numRows === 3) + assert(spHorz2.numRows === 3) + assert(spHorz3.numRows === 3) + assert(spHorz.numRows === 3) + assert(deHorz1.numCols === 5) + assert(spHorz2.numCols === 5) + assert(spHorz3.numCols === 5) + assert(spHorz.numCols === 5) + assert(deHorz2.numRows === 0) + assert(deHorz2.numCols === 0) + assert(deHorz2.toArray.length === 0) + + assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) + assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(spHorz(0, 0) === 1.0) + assert(spHorz(2, 1) === 5.0) + assert(spHorz(0, 2) === 1.0) + assert(spHorz(1, 2) === 0.0) + assert(spHorz(1, 3) === 1.0) + assert(spHorz(2, 4) === 1.0) + assert(spHorz(1, 4) === 0.0) + assert(deHorz1(0, 0) === 1.0) + assert(deHorz1(2, 1) === 5.0) + assert(deHorz1(0, 2) === 1.0) + assert(deHorz1(1, 2) == 0.0) + assert(deHorz1(1, 3) === 1.0) + assert(deHorz1(2, 4) === 1.0) + assert(deHorz1(1, 4) === 0.0) + + intercept[IllegalArgumentException] { + Matrices.horzcat(Array(spMat1, spMat3)) + } + + intercept[IllegalArgumentException] { + Matrices.horzcat(Array(deMat1, spMat3)) + } + + val spVert = Matrices.vertcat(Array(spMat1, spMat3)) + val deVert1 = Matrices.vertcat(Array(deMat1, deMat3)) + val spVert2 = Matrices.vertcat(Array(spMat1, deMat3)) + val spVert3 = Matrices.vertcat(Array(deMat1, spMat3)) + val deVert2 = Matrices.vertcat(Array[Matrix]()) + + assert(deVert1.numRows === 5) + assert(spVert2.numRows === 5) + assert(spVert3.numRows === 5) + assert(spVert.numRows === 5) + assert(deVert1.numCols === 2) + assert(spVert2.numCols === 2) + assert(spVert3.numCols === 2) + assert(spVert.numCols === 2) + assert(deVert2.numRows === 0) + assert(deVert2.numCols === 0) + assert(deVert2.toArray.length === 0) + + assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) + assert(spVert2.toBreeze === spVert3.toBreeze) + assert(spVert(0, 0) === 1.0) + assert(spVert(2, 1) === 5.0) + assert(spVert(3, 0) === 1.0) + assert(spVert(3, 1) === 0.0) + assert(spVert(4, 1) === 1.0) + assert(deVert1(0, 0) === 1.0) + assert(deVert1(2, 1) === 5.0) + assert(deVert1(3, 0) === 1.0) + assert(deVert1(3, 1) === 0.0) + assert(deVert1(4, 1) === 1.0) + + intercept[IllegalArgumentException] { + Matrices.vertcat(Array(spMat1, spMat2)) + } + + intercept[IllegalArgumentException] { + Matrices.vertcat(Array(deMat1, spMat2)) + } + } + test("zeros") { val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix] assert(mat.numRows === 2) @@ -162,4 +305,29 @@ class MatricesSuite extends FunSuite { assert(mat.numCols === 2) assert(mat.values.toSeq 
=== Seq(1.0, 0.0, 0.0, 2.0)) } + + test("sprand") { + val rng = mock[Random] + when(rng.nextInt(4)).thenReturn(0, 1, 1, 3, 2, 2, 0, 1, 3, 0) + when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0) + val mat = SparseMatrix.sprand(4, 4, 0.25, rng) + assert(mat.numRows === 4) + assert(mat.numCols === 4) + assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + val mat2 = SparseMatrix.sprand(2, 3, 1.0, rng) + assert(mat2.rowIndices.toSeq === Seq(0, 1, 0, 1, 0, 1)) + assert(mat2.colPtrs.toSeq === Seq(0, 2, 4, 6)) + } + + test("sprandn") { + val rng = mock[Random] + when(rng.nextInt(4)).thenReturn(0, 1, 1, 3, 2, 2, 0, 1, 3, 0) + when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = SparseMatrix.sprandn(4, 4, 0.25, rng) + assert(mat.numRows === 4) + assert(mat.numCols === 4) + assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1)) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 30b906aaa3ba..e957fa5d25f4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -178,17 +178,17 @@ object TestingUtils { implicit class MatrixWithAlmostEquals(val x: Matrix) { /** - * When the difference of two vectors are within eps, returns true; otherwise, returns false. + * When the difference of two matrices are within eps, returns true; otherwise, returns false. */ def ~=(r: CompareMatrixRightSide): Boolean = r.fun(x, r.y, r.eps) /** - * When the difference of two vectors are within eps, returns false; otherwise, returns true. + * When the difference of two matrices are within eps, returns false; otherwise, returns true. */ def !~=(r: CompareMatrixRightSide): Boolean = !r.fun(x, r.y, r.eps) /** - * Throws exception when the difference of two vectors are NOT within eps; + * Throws exception when the difference of two matrices are NOT within eps; * otherwise, returns true. */ def ~==(r: CompareMatrixRightSide): Boolean = { From 9bc0df6804f241aff24520d9c6ec54d9b11f5785 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Mon, 29 Dec 2014 13:49:45 -0800 Subject: [PATCH 009/215] SPARK-4968: takeOrdered to skip reduce step in case mappers return no partitions takeOrdered should skip reduce step in case mapped RDDs have no partitions. This prevents the mentioned exception : 4. 
run query SELECT * FROM testTable WHERE market = 'market2' ORDER BY End_Time DESC LIMIT 100; Error trace java.lang.UnsupportedOperationException: empty collection at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:863) at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:863) at scala.Option.getOrElse(Option.scala:120) at org.apache.spark.rdd.RDD.reduce(RDD.scala:863) at org.apache.spark.rdd.RDD.takeOrdered(RDD.scala:1136) Author: Yash Datta Closes #3830 from saucam/fix_takeorder and squashes the following commits: 5974d10 [Yash Datta] SPARK-4968: takeOrdered to skip reduce step in case mappers return no partitions --- .../src/main/scala/org/apache/spark/rdd/RDD.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f47c2d1fcdcc..5118e2b91112 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1146,15 +1146,20 @@ abstract class RDD[T: ClassTag]( if (num == 0) { Array.empty } else { - mapPartitions { items => + val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) queue ++= util.collection.Utils.takeOrdered(items, num)(ord) Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord) + } + if (mapRDDs.partitions.size == 0) { + Array.empty + } else { + mapRDDs.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray.sorted(ord) + } } } From 6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464 Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Mon, 29 Dec 2014 15:29:15 -0800 Subject: [PATCH 010/215] SPARK-4156 [MLLIB] EM algorithm for GMMs Implementation of Expectation-Maximization for Gaussian Mixture Models. This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license. Author: Travis Galoppo Author: Travis Galoppo Author: tgaloppo Author: FlytxtRnD Closes #3022 from tgaloppo/master and squashes the following commits: aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib] 709e4bf [Travis Galoppo] fixed usage line to include optional maxIterations parameter acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel Made maximum iterations an optional parameter to DenseGmmEM 9b2fc2a [Travis Galoppo] Style improvements Changed ExpectationSum to a private class b97fe00 [Travis Galoppo] Minor fixes and tweaks. 1de73f3 [Travis Galoppo] Removed redundant array from array creation 578c2d1 [Travis Galoppo] Removed unused import 227ad66 [Travis Galoppo] Moved prediction methods into model class. 308c8ad [Travis Galoppo] Numerous changes to improve code cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate 20ebca1 [Travis Galoppo] Removed unusued code 42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point. Added two cluster test to testing suite. 
8b633f3 [Travis Galoppo] Style issue 9be2534 [Travis Galoppo] Style issue d695034 [Travis Galoppo] Fixed style issues c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark Adds predict() method 2df336b [Travis Galoppo] Fixed style issue b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values 97044cf [Travis Galoppo] Fixed style issues dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl Improved comments 9770261 [Travis Galoppo] Corrected a variety of style and naming issues. 8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count and tolerance parameters. 676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line e6ea805 [Travis Galoppo] Merged with master branch; update test suite with latest context changes. Improved cluster initialization strategy. 86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' 719d8cc [Travis Galoppo] Added scala test suite with basic test c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable Modified sum function for better performance 5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' c15405c [Travis Galoppo] SPARK-4156 --- .../spark/examples/mllib/DenseGmmEM.scala | 67 +++++ .../mllib/clustering/GaussianMixtureEM.scala | 241 ++++++++++++++++++ .../clustering/GaussianMixtureModel.scala | 91 +++++++ .../stat/impl/MultivariateGaussian.scala | 39 +++ .../org/apache/spark/mllib/util/MLUtils.scala | 2 +- .../GMMExpectationMaximizationSuite.scala | 78 ++++++ 6 files changed, 517 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala new file mode 100644 index 000000000000..948c350953e2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.clustering.GaussianMixtureEM +import org.apache.spark.mllib.linalg.Vectors + +/** + * An example Gaussian Mixture Model EM app. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DenseGmmEM { + def main(args: Array[String]): Unit = { + if (args.length < 3) { + println("usage: DenseGmmEM [maxIterations]") + } else { + val maxIterations = if (args.length > 3) args(3).toInt else 100 + run(args(0), args(1).toInt, args(2).toDouble, maxIterations) + } + } + + private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { + val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") + val ctx = new SparkContext(conf) + + val data = ctx.textFile(inputFile).map { line => + Vectors.dense(line.trim.split(' ').map(_.toDouble)) + }.cache() + + val clusters = new GaussianMixtureEM() + .setK(k) + .setConvergenceTol(convergenceTol) + .setMaxIterations(maxIterations) + .run(data) + + for (i <- 0 until clusters.k) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (clusters.weight(i), clusters.mu(i), clusters.sigma(i))) + } + + println("Cluster labels (first <= 100):") + val clusterLabels = clusters.predict(data) + clusterLabels.take(100).foreach { x => + print(" " + x) + } + println() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala new file mode 100644 index 000000000000..bdf984aee4da --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable.IndexedSeq + +import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.util.MLUtils + +/** + * This class performs expectation maximization for multivariate Gaussian + * Mixture Models (GMMs). A GMM represents a composite distribution of + * independent Gaussian distributions with associated "mixing" weights + * specifying each's contribution to the composite. 
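A note on the class doc above: concretely, the mixture assigns each point x the density p(x) = sum_i w(i) * N(x; mu(i), sigma(i)), and EM alternates between computing per-component responsibilities under the current parameters (the "E" step, which the ExpectationSum aggregation further down distributes across the cluster) and re-estimating w, mu and sigma from those responsibilities (the "M" step). A small self-contained sketch of the responsibility computation; the names below are illustrative, not from the patch, and the real code also adds MLUtils.EPSILON to each term to avoid dividing by zero:

    // gamma(i) = w(i) * pdf_i(x) / sum_j w(j) * pdf_j(x)
    // componentPdf(i, x) stands in for the Gaussian density N(x; mu(i), sigma(i)).
    def responsibilities(
        weights: Array[Double],
        componentPdf: (Int, Array[Double]) => Double,
        x: Array[Double]): Array[Double] = {
      val unnormalized = weights.indices.map(i => weights(i) * componentPdf(i, x)).toArray
      val total = unnormalized.sum
      unnormalized.map(_ / total)
    }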
+ * + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * less than convergenceTol, or until it has reached the max number of iterations. + * While this process is generally guaranteed to converge, it is not guaranteed + * to find a global optimum. + * + * @param k The number of independent Gaussians in the mixture model + * @param convergenceTol The maximum change in log-likelihood at which convergence + * is considered to have occurred. + * @param maxIterations The maximum number of iterations to perform + */ +class GaussianMixtureEM private ( + private var k: Int, + private var convergenceTol: Double, + private var maxIterations: Int) extends Serializable { + + /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ + def this() = this(2, 0.01, 100) + + // number of samples per cluster to use when initializing Gaussians + private val nSamples = 5 + + // an initializing GMM can be provided rather than using the + // default random starting point + private var initialModel: Option[GaussianMixtureModel] = None + + /** Set the initial GMM starting point, bypassing the random initialization. + * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + */ + def setInitialModel(model: GaussianMixtureModel): this.type = { + if (model.k == k) { + initialModel = Some(model) + } else { + throw new IllegalArgumentException("mismatched cluster count (model.k != k)") + } + this + } + + /** Return the user supplied initial GMM, if supplied */ + def getInitialModel: Option[GaussianMixtureModel] = initialModel + + /** Set the number of Gaussians in the mixture model. Default: 2 */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Return the number of Gaussians in the mixture model */ + def getK: Int = k + + /** Set the maximum number of iterations to run. Default: 100 */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Return the maximum number of iterations to run */ + def getMaxIterations: Int = maxIterations + + /** + * Set the largest change in log-likelihood at which convergence is + * considered to have occurred. + */ + def setConvergenceTol(convergenceTol: Double): this.type = { + this.convergenceTol = convergenceTol + this + } + + /** Return the largest change in log-likelihood at which convergence is + * considered to have occurred. + */ + def getConvergenceTol: Double = convergenceTol + + /** Perform expectation maximization */ + def run(data: RDD[Vector]): GaussianMixtureModel = { + val sc = data.sparkContext + + // we will operate on the data as breeze data + val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() + + // Get length of the input vectors + val d = breezeData.first.length + + // Determine initial weights and corresponding Gaussians. 
+ // If the user supplied an initial GMM, we use those values, otherwise + // we start with uniform weights, a random mean from the data, and + // diagonal covariance matrices using component variances + // derived from the samples + val (weights, gaussians) = initialModel match { + case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => + new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix) + }) + + case None => { + val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + val slice = samples.view(i * nSamples, (i + 1) * nSamples) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + }) + } + } + + var llh = Double.MinValue // current log-likelihood + var llhp = 0.0 // previous log-likelihood + + var iter = 0 + while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) { + // create and broadcast curried cluster contribution function + val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) + + // aggregate the cluster contribution for all sample points + val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) + + // Create new distributions based on the partial assignments + // (often referred to as the "M" step in literature) + val sumWeights = sums.weights.sum + var i = 0 + while (i < k) { + val mu = sums.means(i) / sums.weights(i) + val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr + weights(i) = sums.weights(i) / sumWeights + gaussians(i) = new MultivariateGaussian(mu, sigma) + i = i + 1 + } + + llhp = llh // current becomes previous + llh = sums.logLikelihood // this is the freshly computed log-likelihood + iter += 1 + } + + // Need to convert the breeze matrices to MLlib matrices + val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) } + val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) } + new GaussianMixtureModel(weights, means, sigmas) + } + + /** Average of dense breeze vectors */ + private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = { + val v = BreezeVector.zeros[Double](x(0).length) + x.foreach(xi => v += xi) + v / x.length.toDouble + } + + /** + * Construct matrix where diagonal entries are element-wise + * variance of input vectors (computes biased variance) + */ + private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = { + val mu = vectorMean(x) + val ss = BreezeVector.zeros[Double](x(0).length) + x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u) + diag(ss / x.length.toDouble) + } +} + +// companion class to provide zero constructor for ExpectationSum +private object ExpectationSum { + def zero(k: Int, d: Int): ExpectationSum = { + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + } + + // compute cluster contributions for each input point + // (U, T) => U for aggregation + def add( + weights: Array[Double], + dists: Array[MultivariateGaussian]) + (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = { + val p = weights.zip(dists).map { + case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x) + } + val pSum = p.sum + sums.logLikelihood += math.log(pSum) + val xxt = x * new Transpose(x) + var i = 0 + while (i < sums.k) { + p(i) /= pSum + sums.weights(i) += p(i) + sums.means(i) += x * p(i) + sums.sigmas(i) += xxt * p(i) // TODO: use 
BLAS.dsyr + i = i + 1 + } + sums + } +} + +// Aggregation class for partial expectation results +private class ExpectationSum( + var logLikelihood: Double, + val weights: Array[Double], + val means: Array[BreezeVector[Double]], + val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { + + val k = weights.length + + def +=(x: ExpectationSum): ExpectationSum = { + var i = 0 + while (i < k) { + weights(i) += x.weights(i) + means(i) += x.means(i) + sigmas(i) += x.sigmas(i) + i = i + 1 + } + logLikelihood += x.logLikelihood + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala new file mode 100644 index 000000000000..11a110db1f7c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseVector => BreezeVector} + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.util.MLUtils + +/** + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * + * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is + * the weight for Gaussian i, and weight.sum == 1 + * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i + * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the + * covariance matrix for Gaussian i + */ +class GaussianMixtureModel( + val weight: Array[Double], + val mu: Array[Vector], + val sigma: Array[Matrix]) extends Serializable { + + /** Number of gaussians in mixture */ + def k: Int = weight.length + + /** Maps given points to their cluster indices. */ + def predict(points: RDD[Vector]): RDD[Int] = { + val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k) + responsibilityMatrix.map(r => r.indexOf(r.max)) + } + + /** + * Given the input vectors, return the membership value of each vector + * to all mixture components. 
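For orientation, a minimal usage sketch of the API this commit introduces, using only names that appear in the diffs above (GaussianMixtureEM, its setters, run, GaussianMixtureModel.predict and weight); an existing SparkContext sc is assumed:

    import org.apache.spark.mllib.clustering.GaussianMixtureEM
    import org.apache.spark.mllib.linalg.Vectors

    val points = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0), Vectors.dense(1.1, 2.1),
      Vectors.dense(9.0, 9.5), Vectors.dense(9.2, 9.4)))

    val model = new GaussianMixtureEM()
      .setK(2)
      .setConvergenceTol(0.01)
      .setMaxIterations(100)
      .run(points)

    val labels = model.predict(points)   // RDD[Int]: most likely component per point
    println(model.weight.toSeq)          // mixing weights, summing to 1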
+ */ + def predictMembership( + points: RDD[Vector], + mu: Array[Vector], + sigma: Array[Matrix], + weight: Array[Double], + k: Int): RDD[Array[Double]] = { + val sc = points.sparkContext + val dists = sc.broadcast { + (0 until k).map { i => + new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix) + }.toArray + } + val weights = sc.broadcast(weight) + points.map { x => + computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k) + } + } + + /** + * Compute the partial assignments for each vector + */ + private def computeSoftAssignments( + pt: BreezeVector[Double], + dists: Array[MultivariateGaussian], + weights: Array[Double], + k: Int): Array[Double] = { + val p = weights.zip(dists).map { + case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) + } + val pSum = p.sum + for (i <- 0 until k) { + p(i) /= pSum + } + p + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala new file mode 100644 index 000000000000..2eab5d277827 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.impl + +import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, Transpose, det, pinv} + +/** + * Utility class to implement the density function for multivariate Gaussian distribution. + * Breeze provides this functionality, but it requires the Apache Commons Math library, + * so this class is here so-as to not introduce a new dependency in Spark. 
+ */ +private[mllib] class MultivariateGaussian( + val mu: DBV[Double], + val sigma: DBM[Double]) extends Serializable { + private val sigmaInv2 = pinv(sigma) * -0.5 + private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5) + + /** Returns density of this multivariate Gaussian at given point, x */ + def pdf(x: DBV[Double]): Double = { + val delta = x - mu + val deltaTranspose = new Transpose(delta) + U * math.exp(deltaTranspose * sigmaInv2 * delta) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b0d05ae33e1b..1d07b5dab826 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.dstream.DStream */ object MLUtils { - private[util] lazy val EPSILON = { + private[mllib] lazy val EPSILON = { var eps = 1.0 while ((1.0 + (eps / 2.0)) != 1.0) { eps /= 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala new file mode 100644 index 000000000000..23feb82874b7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
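A note on the MultivariateGaussian class above: pdf evaluates the standard multivariate normal density

    N(x; mu, Sigma) = (2 * pi)^(-d/2) * det(Sigma)^(-1/2) * exp(-1/2 * (x - mu)^T * Sigma^(-1) * (x - mu))

with d = mu.length. The constructor precomputes the two pieces: U is the normalizing constant in front, and sigmaInv2 folds the -1/2 factor into the (pseudo-)inverse so that pdf only has to evaluate U * exp(delta^T * sigmaInv2 * delta). pinv, a pseudo-inverse, is used instead of a plain inverse, which at least remains defined when sigma is singular. As a sanity check, in one dimension with mu = 0 and sigma = 1 the formula reduces to (2 * pi)^(-1/2) * exp(-x^2 / 2), so pdf(0) is about 0.3989.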
+ */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext { + test("single cluster") { + val data = sc.parallelize(Array( + Vectors.dense(6.0, 9.0), + Vectors.dense(5.0, 10.0), + Vectors.dense(4.0, 11.0) + )) + + // expectations + val Ew = 1.0 + val Emu = Vectors.dense(5.0, 10.0) + val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) + + val gmm = new GaussianMixtureEM().setK(1).run(data) + + assert(gmm.weight(0) ~== Ew absTol 1E-5) + assert(gmm.mu(0) ~== Emu absTol 1E-5) + assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + } + + test("two clusters") { + val data = sc.parallelize(Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + )) + + // we set an initial gaussian to induce expected results + val initialGmm = new GaussianMixtureModel( + Array(0.5, 0.5), + Array(Vectors.dense(-1.0), Vectors.dense(1.0)), + Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0))) + ) + + val Ew = Array(1.0 / 3.0, 2.0 / 3.0) + val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) + val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) + + val gmm = new GaussianMixtureEM() + .setK(2) + .setInitialModel(initialGmm) + .run(data) + + assert(gmm.weight(0) ~== Ew(0) absTol 1E-3) + assert(gmm.weight(1) ~== Ew(1) absTol 1E-3) + assert(gmm.mu(0) ~== Emu(0) absTol 1E-3) + assert(gmm.mu(1) ~== Emu(1) absTol 1E-3) + assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3) + assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3) + } +} From 343db392b58fb33a3e4bc6fda1da69aaf686b5a9 Mon Sep 17 00:00:00 2001 From: ganonp Date: Mon, 29 Dec 2014 15:31:19 -0800 Subject: [PATCH 011/215] Added setMinCount to Word2Vec.scala Wanted to customize the private minCount variable in the Word2Vec class. Added a method to do so. Author: ganonp Closes #3693 from ganonp/my-custom-spark and squashes the following commits: ad534f2 [ganonp] made norm method public 5110a6f [ganonp] Reorganized 854958b [ganonp] Fixed Indentation for setMinCount 12ed8f9 [ganonp] Update Word2Vec.scala 76bdf5a [ganonp] Update Word2Vec.scala ffb88bb [ganonp] Update Word2Vec.scala 5eb9100 [ganonp] Added setMinCount to Word2Vec.scala --- .../org/apache/spark/mllib/feature/Word2Vec.scala | 15 +++++++++++---- .../org/apache/spark/mllib/linalg/Vectors.scala | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7960f3cab576..d25a7cd5b439 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -71,7 +71,8 @@ class Word2Vec extends Serializable with Logging { private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() - + private var minCount = 5 + /** * Sets vector size (default: 100). 
*/ @@ -114,6 +115,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + * model's vocabulary (default: 5). + */ + def setMinCount(minCount: Int): this.type = { + this.minCount = minCount + this + } + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -122,9 +132,6 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private val window = 5 - /** minimum frequency to consider a vocabulary word */ - private val minCount = 5 - private var trainWordsCount = 0 private var vocabSize = 0 private var vocab: Array[VocabWord] = null diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 47d1a76fa361..01f3f9057714 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -268,7 +268,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ - private[spark] def norm(vector: Vector, p: Double): Double = { + def norm(vector: Vector, p: Double): Double = { require(p >= 1.0) val values = vector match { case dv: DenseVector => dv.values From 040d6f2d13b132b3ef2a1e4f12f9f0e781c5a0b8 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 29 Dec 2014 17:17:12 -0800 Subject: [PATCH 012/215] [SPARK-4972][MLlib] Updated the scala doc for lasso and ridge regression for the change of LeastSquaresGradient In #SPARK-4907, we added factor of 2 into the LeastSquaresGradient. We updated the scala doc for lasso and ridge regression here. Author: DB Tsai Closes #3808 from dbtsai/doc and squashes the following commits: ec3c989 [DB Tsai] first commit --- .../main/scala/org/apache/spark/mllib/regression/Lasso.scala | 2 +- .../org/apache/spark/mllib/regression/RidgeRegression.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index f9791c657178..8ecd5c6ad93c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -45,7 +45,7 @@ class LassoModel ( /** * Train a regression model with L1-regularization using Stochastic Gradient Descent. * This solves the l1-regularized least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + regParam ||weights||_1 + * f(weights) = 1/2n ||A weights-y||^2 + regParam ||weights||_1 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index c8cad773f5ef..076ba35051c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -45,7 +45,7 @@ class RidgeRegressionModel ( /** * Train a regression model with L2-regularization using Stochastic Gradient Descent. 
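A note on the two API changes from the "Added setMinCount to Word2Vec.scala" patch above: setMinCount exposes the vocabulary frequency cutoff that used to be a hard-coded private val, and Vectors.norm drops its private[spark] modifier. A brief usage sketch; the threshold value 10 is made up for illustration:

    import org.apache.spark.mllib.feature.Word2Vec
    import org.apache.spark.mllib.linalg.Vectors

    // Keep only tokens that occur at least 10 times (the default stays 5).
    val word2vec = new Word2Vec().setVectorSize(100).setMinCount(10)

    // norm(vector, p) computes the L^p norm and is now callable from user code.
    val l2 = Vectors.norm(Vectors.dense(3.0, 4.0), 2.0)   // 5.0
    val l1 = Vectors.norm(Vectors.dense(3.0, 4.0), 1.0)   // 7.0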
* This solves the l1-regularized least squares regression formulation - * f(weights) = 1/n ||A weights-y||^2 + regParam/2 ||weights||^2 + * f(weights) = 1/2n ||A weights-y||^2 + regParam/2 ||weights||^2 * Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with * its corresponding right hand side label y. * See also the documentation for the precise formulation. From 9077e721cd36adfecd50cbd1fd7735d28e5be8b5 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 30 Dec 2014 09:19:47 -0800 Subject: [PATCH 013/215] [SPARK-4920][UI] add version on master and worker page for standalone mode Author: Zhang, Liye Closes #3769 from liyezhang556520/spark-4920_WebVersion and squashes the following commits: 3bb7e0d [Zhang, Liye] add version on master and worker page --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 7486cb6b1bbc..b5022fe853c4 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -234,8 +234,9 @@ private[spark] object UIUtils extends Logging {

- + + {org.apache.spark.SPARK_VERSION} {title}

From efa80a531ecd485f6cf0cdc24ffa42ba17eea46d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 30 Dec 2014 09:29:52 -0800 Subject: [PATCH 014/215] [SPARK-4882] Register PythonBroadcast with Kryo so that PySpark works with KryoSerializer This PR fixes an issue where PySpark broadcast variables caused NullPointerExceptions if KryoSerializer was used. The fix is to register PythonBroadcast with Kryo so that it's deserialized with a KryoJavaSerializer. Author: Josh Rosen Closes #3831 from JoshRosen/SPARK-4882 and squashes the following commits: 0466c7a [Josh Rosen] Register PythonBroadcast with Kryo. d5b409f [Josh Rosen] Enable registrationRequired, which would have caught this bug. 069d8a7 [Josh Rosen] Add failing test for SPARK-4882 --- .../spark/serializer/KryoSerializer.scala | 2 + .../api/python/PythonBroadcastSuite.scala | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 621a951c27d0..d2947dcea4f7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ +import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus @@ -90,6 +91,7 @@ class KryoSerializer(conf: SparkConf) // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) + kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) try { // Use the default classloader when calling the user registrator. diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala new file mode 100644 index 000000000000..8959a843dbd7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
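A note on the KryoSerializer change above: PythonBroadcast carries its payload through custom Java serialization hooks, so registering it with KryoJavaSerializer tells Kryo to delegate to Java serialization for that one class rather than serializing its fields directly (this reading is inferred from the diff and the squashed commit notes, not spelled out in the message). The squash history also mentions spark.kryo.registrationRequired, the setting that makes unregistered classes fail fast. A hedged configuration sketch:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // Fail fast on any class that was never registered with Kryo.
      .set("spark.kryo.registrationRequired", "true")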
+ */ + +package org.apache.spark.api.python + +import scala.io.Source + +import java.io.{PrintWriter, File} + +import org.scalatest.{Matchers, FunSuite} + +import org.apache.spark.{SharedSparkContext, SparkConf} +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.Utils + +// This test suite uses SharedSparkContext because we need a SparkEnv in order to deserialize +// a PythonBroadcast: +class PythonBroadcastSuite extends FunSuite with Matchers with SharedSparkContext { + test("PythonBroadcast can be serialized with Kryo (SPARK-4882)") { + val tempDir = Utils.createTempDir() + val broadcastedString = "Hello, world!" + def assertBroadcastIsValid(broadcast: PythonBroadcast): Unit = { + val source = Source.fromFile(broadcast.path) + val contents = source.mkString + source.close() + contents should be (broadcastedString) + } + try { + val broadcastDataFile: File = { + val file = new File(tempDir, "broadcastData") + val printWriter = new PrintWriter(file) + printWriter.write(broadcastedString) + printWriter.close() + file + } + val broadcast = new PythonBroadcast(broadcastDataFile.getAbsolutePath) + assertBroadcastIsValid(broadcast) + val conf = new SparkConf().set("spark.kryo.registrationRequired", "true") + val deserializedBroadcast = + Utils.clone[PythonBroadcast](broadcast, new KryoSerializer(conf).newInstance()) + assertBroadcastIsValid(deserializedBroadcast) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} From 480bd1d2edd1de06af607b0cf3ff3c0b16089add Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 30 Dec 2014 11:24:46 -0800 Subject: [PATCH 015/215] [SPARK-4908][SQL] Prevent multiple concurrent hive native commands This is just a quick fix that locks when calling `runHive`. If we can find a way to avoid the error without a global lock that would be better. Author: Michael Armbrust Closes #3834 from marmbrus/hiveConcurrency and squashes the following commits: bf25300 [Michael Armbrust] prevent multiple concurrent hive native commands --- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 56fe27a77b83..982e0593fcfd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -284,7 +284,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Execute the command using Hive and return the results as a sequence. Each element * in the sequence is one row. 
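A note on the runHive change that follows: making the method body synchronized locks on the HiveContext instance, so at most one native command runs at a time per context (effectively the global lock the commit message describes when there is a single context, as in the new test). The commit only states that locking avoids the error; that the shared Hive session state is unsafe for concurrent use is an inference, so treat it as such. The pattern in isolation, with illustrative names:

    // `synchronized` on a method body uses the enclosing instance as the monitor,
    // so at most one thread at a time executes the guarded section per object.
    class CommandRunner {
      def run(cmd: String): Seq[String] = synchronized {
        Seq(s"ran: $cmd")   // placeholder for the non-thread-safe work
      }
    }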
*/ - protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = { + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = synchronized { try { val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4d81acc753a2..fb6da33e88ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -56,6 +56,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { Locale.setDefault(originalLocale) } + test("SPARK-4908: concurent hive native commands") { + (1 to 100).par.map { _ => + sql("USE default") + sql("SHOW TABLES") + } + } + createQueryTest("constant object inspector for generic udf", """SELECT named_struct( lower("AA"), "10", From 94d60b7021960dc10d98039dbc6ad7193e8557f5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 30 Dec 2014 11:29:13 -0800 Subject: [PATCH 016/215] [SQL] enable view test This is a follow up of #3396 , just add a test to white list. Author: Daoyuan Wang Closes #3826 from adrian-wang/viewtest and squashes the following commits: f105f68 [Daoyuan Wang] enable view test --- .../execution/HiveCompatibilitySuite.scala | 3 +- ...anslate-0-dc7fc9ce5109ef459ee84ccfbb12d2c0 | 0 ...anslate-1-3896ae0e680a5fdc01833533b11c07bb | 0 ...nslate-10-7016e1e3a4248564f3d08cddad7ae116 | 0 ...nslate-11-e27c6a59a833dcbc2e5cdb7ff7972828 | 0 ...anslate-2-6b4caec6d7e3a91e61720bbd6b7697f0 | 0 ...anslate-3-30dc3e80e3873af5115e4f5e39078a13 | 27 ++++++++++++++++ ...anslate-4-cefb7530126f9e60cb4a29441d578f23 | 0 ...anslate-5-856ea995681b18a543dc0e53b8b43a8e | 32 +++++++++++++++++++ ...anslate-6-a14cfe3eff322066e61023ec06c7735d | 0 ...anslate-7-e947bf2dacc907825df154a4131a3fcc | 0 ...anslate-8-b1a99b0beffb0b298aec9233ecc0707f | 0 ...anslate-9-fc0dc39c4796d917685e0797bc4a9786 | 0 13 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-0-dc7fc9ce5109ef459ee84ccfbb12d2c0 create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-1-3896ae0e680a5fdc01833533b11c07bb create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-10-7016e1e3a4248564f3d08cddad7ae116 create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-11-e27c6a59a833dcbc2e5cdb7ff7972828 create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-2-6b4caec6d7e3a91e61720bbd6b7697f0 create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-4-cefb7530126f9e60cb4a29441d578f23 create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-6-a14cfe3eff322066e61023ec06c7735d create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-7-e947bf2dacc907825df154a4131a3fcc create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-8-b1a99b0beffb0b298aec9233ecc0707f create mode 100644 sql/hive/src/test/resources/golden/create_view_translate-9-fc0dc39c4796d917685e0797bc4a9786 diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 1e44dd239458..23283fd3fe6b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -101,6 +101,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "describe_comment_nonascii", "create_merge_compressed", + "create_view", "create_view_partitioned", "database_location", "database_properties", @@ -110,7 +111,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Weird DDL differences result in failures on jenkins. "create_like2", - "create_view_translate", "partitions_json", // This test is totally fine except that it includes wrong queries and expects errors, but error @@ -349,6 +349,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "create_nested_type", "create_skewed_table1", "create_struct_table", + "create_view_translate", "cross_join", "cross_product_check_1", "cross_product_check_2", diff --git a/sql/hive/src/test/resources/golden/create_view_translate-0-dc7fc9ce5109ef459ee84ccfbb12d2c0 b/sql/hive/src/test/resources/golden/create_view_translate-0-dc7fc9ce5109ef459ee84ccfbb12d2c0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-1-3896ae0e680a5fdc01833533b11c07bb b/sql/hive/src/test/resources/golden/create_view_translate-1-3896ae0e680a5fdc01833533b11c07bb new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-10-7016e1e3a4248564f3d08cddad7ae116 b/sql/hive/src/test/resources/golden/create_view_translate-10-7016e1e3a4248564f3d08cddad7ae116 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-11-e27c6a59a833dcbc2e5cdb7ff7972828 b/sql/hive/src/test/resources/golden/create_view_translate-11-e27c6a59a833dcbc2e5cdb7ff7972828 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-2-6b4caec6d7e3a91e61720bbd6b7697f0 b/sql/hive/src/test/resources/golden/create_view_translate-2-6b4caec6d7e3a91e61720bbd6b7697f0 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 b/sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 new file mode 100644 index 000000000000..cec5f77033aa --- /dev/null +++ b/sql/hive/src/test/resources/golden/create_view_translate-3-30dc3e80e3873af5115e4f5e39078a13 @@ -0,0 +1,27 @@ +# col_name data_type comment + +key string + +# Detailed Table Information +Database: default +Owner: animal +CreateTime: Mon Dec 29 00:57:55 PST 2014 +LastAccessTime: UNKNOWN +Protect Mode: None +Retention: 0 +Table Type: VIRTUAL_VIEW +Table Parameters: + transient_lastDdlTime 1419843475 + +# Storage Information +SerDe Library: null +InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat +OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat +Compressed: No +Num Buckets: -1 +Bucket Columns: [] +Sort Columns: [] + +# View Information +View Original Text: select cast(key as string) from src 
+View Expanded Text: select cast(`src`.`key` as string) from `default`.`src` diff --git a/sql/hive/src/test/resources/golden/create_view_translate-4-cefb7530126f9e60cb4a29441d578f23 b/sql/hive/src/test/resources/golden/create_view_translate-4-cefb7530126f9e60cb4a29441d578f23 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e b/sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e new file mode 100644 index 000000000000..bf582fc0964a --- /dev/null +++ b/sql/hive/src/test/resources/golden/create_view_translate-5-856ea995681b18a543dc0e53b8b43a8e @@ -0,0 +1,32 @@ +# col_name data_type comment + +key int +value string + +# Detailed Table Information +Database: default +Owner: animal +CreateTime: Mon Dec 29 00:57:55 PST 2014 +LastAccessTime: UNKNOWN +Protect Mode: None +Retention: 0 +Table Type: VIRTUAL_VIEW +Table Parameters: + transient_lastDdlTime 1419843475 + +# Storage Information +SerDe Library: null +InputFormat: org.apache.hadoop.mapred.SequenceFileInputFormat +OutputFormat: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat +Compressed: No +Num Buckets: -1 +Bucket Columns: [] +Sort Columns: [] + +# View Information +View Original Text: select key, value from ( + select key, value from src +) a +View Expanded Text: select key, value from ( + select `src`.`key`, `src`.`value` from `default`.`src` +) `a` diff --git a/sql/hive/src/test/resources/golden/create_view_translate-6-a14cfe3eff322066e61023ec06c7735d b/sql/hive/src/test/resources/golden/create_view_translate-6-a14cfe3eff322066e61023ec06c7735d new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-7-e947bf2dacc907825df154a4131a3fcc b/sql/hive/src/test/resources/golden/create_view_translate-7-e947bf2dacc907825df154a4131a3fcc new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-8-b1a99b0beffb0b298aec9233ecc0707f b/sql/hive/src/test/resources/golden/create_view_translate-8-b1a99b0beffb0b298aec9233ecc0707f new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sql/hive/src/test/resources/golden/create_view_translate-9-fc0dc39c4796d917685e0797bc4a9786 b/sql/hive/src/test/resources/golden/create_view_translate-9-fc0dc39c4796d917685e0797bc4a9786 new file mode 100644 index 000000000000..e69de29bb2d1 From 65357f11c25a7c91577df5da31ebf349d7845eef Mon Sep 17 00:00:00 2001 From: scwf Date: Tue, 30 Dec 2014 11:30:47 -0800 Subject: [PATCH 017/215] [SPARK-4975][SQL] Fix HiveInspectorSuite test failure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HiveInspectorSuite test failure: [info] - wrap / unwrap null, constant null and writables *** FAILED *** (21 milliseconds) [info] 1 did not equal 0 (HiveInspectorSuite.scala:136) this is because the origin date(is 3914-10-23) not equals the date returned by ```unwrap```(is 3914-10-22). Setting TimeZone and Locale fix this. 
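A note on the failure analysis above: this looks like a time-zone dependency, where a java.sql.Date value round-trips through the Hive object inspectors using the JVM default zone and can come back shifted by one calendar day on a machine in a different zone. The fix in the hunk below pins the JVM defaults for the whole suite; the pattern in isolation is just:

    import java.util.{Locale, TimeZone}

    // Pin JVM defaults so date/timestamp comparisons do not depend on the machine.
    TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
    Locale.setDefault(Locale.US)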
Another minor change here is rename ```def checkValues(v1: Any, v2: Any): Unit``` to ```def checkValue(v1: Any, v2: Any): Unit ``` to make the code more clear Author: scwf Author: Fei Wang Closes #3814 from scwf/fix-inspectorsuite and squashes the following commits: d8531ef [Fei Wang] Delete test.log 72b19a9 [scwf] fix HiveInspectorSuite test error --- .../spark/sql/hive/HiveInspectorSuite.scala | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index bfe608a51a30..f90d3607915a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.sql.Date import java.util +import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.serde2.io.DoubleWritable import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory @@ -63,6 +64,11 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { .get()) } + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + val data = Literal(true) :: Literal(0.asInstanceOf[Byte]) :: @@ -121,11 +127,11 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = { row1.zip(row2).map { - case (r1, r2) => checkValues(r1, r2) + case (r1, r2) => checkValue(r1, r2) } } - def checkValues(v1: Any, v2: Any): Unit = { + def checkValue(v1: Any, v2: Any): Unit = { (v1, v2) match { case (r1: Decimal, r2: Decimal) => // Ignore the Decimal precision @@ -195,26 +201,26 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { }) checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row]) - checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) val d = row(0) :: row(0) :: Nil - checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) val d = Map(row(0) -> row(1)) - checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) - checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) - checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt)), 
toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) + checkValue(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt)))) } } From 5595eaa74f139fdb6fd8a7bb0ca6ed421ef00ac8 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 30 Dec 2014 11:33:47 -0800 Subject: [PATCH 018/215] [SPARK-4959] [SQL] Attributes are case sensitive when using a select query from a projection Author: Cheng Hao Closes #3796 from chenghao-intel/spark_4959 and squashes the following commits: 3ec08f8 [Cheng Hao] Replace the attribute in comparing its exprId other than itself --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++---- .../sql/hive/execution/HiveTableScanSuite.scala | 14 +++++++++++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 806c1394eb15..0f2eae6400d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,16 +142,16 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList1, Project(projectList2, child)) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). - val aliasMap = projectList2.collect { - case a @ Alias(e, _) => (a.toAttribute: Expression, a) - }.toMap + val aliasMap = AttributeMap(projectList2.collect { + case a @ Alias(e, _) => (a.toAttribute, a) + }) // Substitute any attributes that are produced by the child projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' // TODO: Fix TransformBase to avoid the cast below. 
val substitutedProjection = projectList1.map(_.transform { - case a if aliasMap.contains(a) => aliasMap(a) + case a: Attribute if aliasMap.contains(a) => aliasMap(a) }).asInstanceOf[Seq[NamedExpression]] Project(substitutedProjection, child) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index a0ace91060a2..16f77a438e1a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.Row import org.apache.spark.util.Utils @@ -76,4 +77,15 @@ class HiveTableScanSuite extends HiveComparisonTest { === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } + + test("Spark-4959 Attributes are case sensitive when using a select query from a projection") { + sql("create table spark_4959 (col1 string)") + sql("""insert into table spark_4959 select "hi" from src limit 1""") + table("spark_4959").select( + 'col1.as('CaseSensitiveColName), + 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + + assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + } } From 63b84b7d6785a687dd7f4c0e2bb1e348800d30d8 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 30 Dec 2014 11:47:08 -0800 Subject: [PATCH 019/215] [SPARK-4904] [SQL] Remove the unnecessary code change in Generic UDF Since #3429 has been merged, the bug of wrapping to Writable for HiveGenericUDF is resolved, we can safely remove the foldable checking in `HiveGenericUdf.eval`, which discussed in #2802. Author: Cheng Hao Closes #3745 from chenghao-intel/generic_udf and squashes the following commits: 622ad03 [Cheng Hao] Remove the unnecessary code change in Generic UDF --- .../src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 93b6ef9fbc59..7d863f9d89da 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -158,11 +158,6 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr override def foldable = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - @transient - protected def constantReturnValue = unwrap( - returnInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue(), - returnInspector) - @transient protected lazy val deferedObjects = argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] @@ -171,7 +166,6 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr override def eval(input: Row): Any = { returnInspector // Make sure initialized. 
- if(foldable) return constantReturnValue var i = 0 while (i < children.length) { From daac221302e0cf71a7b7bda31625134cf7b9dce1 Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 30 Dec 2014 12:07:24 -0800 Subject: [PATCH 020/215] [SPARK-5002][SQL] Using ascending by default when not specify order in order by spark sql does not support ```SELECT a, b FROM testData2 ORDER BY a desc, b```. Author: wangfei Closes #3838 from scwf/orderby and squashes the following commits: 114b64a [wangfei] remove nouse methods 48145d3 [wangfei] fix order, using asc by default --- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 8 ++------ .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index d4fc9bbfd311..66860a4c0923 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -209,15 +209,11 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(singleOrder, ",") - | rep1sep(expression, ",") ~ direction.? ^^ { - case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending))) + ( rep1sep(expression ~ direction.? , ",") ^^ { + case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) - protected lazy val singleOrder: Parser[SortOrder] = - expression ~ direction ^^ { case e ~ o => SortOrder(e, o) } - protected lazy val direction: Parser[SortDirection] = ( ASC ^^^ Ascending | DESC ^^^ Descending diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ddf4776ecf7a..add4e218a22e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -987,6 +987,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) } + test("oder by asc by default when not specify ascending and descending") { + checkAnswer( + sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), + Seq((3, 1), (3, 2), (2, 1), (2,2), (1, 1), (1, 2)) + ) + } + test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) From 53f0a00b6051fb6cb52a90f91ae01bcd77e332c5 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 30 Dec 2014 12:11:44 -0800 Subject: [PATCH 021/215] [Spark-4512] [SQL] Unresolved Attribute Exception in Sort By It will cause exception while do query like: SELECT key+key FROM src sort by value; Author: Cheng Hao Closes #3386 from chenghao-intel/sort and squashes the following commits: 38c78cc [Cheng Hao] revert the SortPartition in SparkStrategies 7e9dd15 [Cheng Hao] update the typo fcd1d64 [Cheng Hao] rebase the latest master and update the SortBy unit test --- .../apache/spark/sql/catalyst/SqlParser.scala | 4 ++-- .../sql/catalyst/analysis/Analyzer.scala | 13 +++++++------ .../spark/sql/catalyst/dsl/package.scala | 4 ++-- .../plans/logical/basicOperators.scala | 11 ++++++++++- .../org/apache/spark/sql/SchemaRDD.scala | 5 ++--- .../spark/sql/execution/SparkStrategies.scala | 11 +++++------ .../org/apache/spark/sql/DslQuerySuite.scala | 19 ++++++++++++++----- 
.../scala/org/apache/spark/sql/TestData.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 8 ++++---- .../hive/execution/HiveComparisonTest.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 7 +++++++ 11 files changed, 55 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 66860a4c0923..f79d4ff444dc 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -204,8 +204,8 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] = - ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, l) } - | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => SortPartitions(o, l) } + ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) } + | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) } ) protected lazy val ordering: Parser[Seq[SortOrder]] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1c4088b8438e..72680f37a0b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -246,7 +246,7 @@ class Analyzer(catalog: Catalog, case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. - case p@Project(projectList, child) if containsStar(projectList) => + case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) @@ -310,7 +310,8 @@ class Analyzer(catalog: Catalog, */ object ResolveSortReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => + case s @ Sort(ordering, global, p @ Project(projectList, child)) + if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) val resolved = unresolved.flatMap(child.resolve(_, resolver)) val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) @@ -319,13 +320,14 @@ class Analyzer(catalog: Catalog, if (missingInProject.nonEmpty) { // Add missing attributes and then project them away after the sort. Project(projectList.map(_.toAttribute), - Sort(ordering, + Sort(ordering, global, Project(projectList ++ missingInProject, child))) } else { logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } - case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => + case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) + if !s.resolved && a.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) // A small hack to create an object that will allow us to resolve any references that // refer to named expressions that are present in the grouping expressions. @@ -340,8 +342,7 @@ class Analyzer(catalog: Catalog, if (missingInAggs.nonEmpty) { // Add missing grouping exprs and then project them away after the sort. 
Project(a.output, - Sort(ordering, - Aggregate(grouping, aggs ++ missingInAggs, child))) + Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child))) } else { s // Nothing we can do here. Return original plan. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index fb252cdf5153..a14e5b9ef14d 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -244,9 +244,9 @@ package object dsl { condition: Option[Expression] = None) = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan) + def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = SortPartitions(sortExprs, logicalPlan) + def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { val aliasedExprs = aggregateExprs.map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a9282b98adfa..0b9f01cbae9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -130,7 +130,16 @@ case class WriteToFile( override def output = child.output } -case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { +/** + * @param order The ordering expressions + * @param global True means global sorting apply for entire data set, + * False means sorting only apply within the partition. + * @param child Child logical plan + */ +case class Sort( + order: Seq[SortOrder], + global: Boolean, + child: LogicalPlan) extends UnaryNode { override def output = child.output } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 856b10f1a8fd..80787b61ce1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -214,7 +214,7 @@ class SchemaRDD( * @group Query */ def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) /** * Sorts the results by the given expressions within partition. 
@@ -227,7 +227,7 @@ class SchemaRDD( * @group Query */ def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, SortPartitions(sortExprs, logicalPlan)) + new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) @deprecated("use limit with integer argument", "1.1.0") def limit(limitExpr: Expression): SchemaRDD = @@ -238,7 +238,6 @@ class SchemaRDD( * {{{ * schemaRDD.limit(10) * }}} - * * @group Query */ def limit(limitNum: Int): SchemaRDD = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2954d4ce7d2d..9151da69ed44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -190,7 +190,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } @@ -257,15 +257,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil - case logical.Sort(sortExprs, child) if sqlContext.externalSortEnabled => - execution.ExternalSort(sortExprs, global = true, planLater(child)):: Nil - case logical.Sort(sortExprs, child) => - execution.Sort(sortExprs, global = true, planLater(child)):: Nil - case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. 
execution.Sort(sortExprs, global = false, planLater(child)) :: Nil + case logical.Sort(sortExprs, global, child) if sqlContext.externalSortEnabled => + execution.ExternalSort(sortExprs, global, planLater(child)):: Nil + case logical.Sort(sortExprs, global, child) => + execution.Sort(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 691c4b38287b..c0b9cf516312 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -88,7 +88,7 @@ class DslQuerySuite extends QueryTest { Seq(Seq(6))) } - test("sorting") { + test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) @@ -122,22 +122,31 @@ class DslQuerySuite extends QueryTest { mapData.collect().sortBy(_.data(1)).reverse.toSeq) } - test("sorting #2") { + test("partition wide sorting") { + // 2 partitions totally, and + // Partition #1 with values: + // (1, 1) + // (1, 2) + // (2, 1) + // Partition #2 with values: + // (2, 2) + // (3, 1) + // (3, 2) checkAnswer( testData2.sortBy('a.asc, 'b.asc), Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) checkAnswer( testData2.sortBy('a.asc, 'b.desc), - Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1))) checkAnswer( testData2.sortBy('a.desc, 'b.desc), - Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2))) checkAnswer( testData2.sortBy('a.desc, 'b.asc), - Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2))) } test("limit") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index bb553a0a1e50..497897c3c0d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -55,7 +55,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toSchemaRDD testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3f3d9e7cd4fb..8a9613cf96e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -680,16 +680,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving) + Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withHaving) case (None, Some(perPartitionOrdering), None, None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving) + Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) case (None, None, Some(partitionExprs), None) => Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), 
Some(partitionExprs), None) => - SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), + Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => - SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), + Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 8011f9b8773b..4104df8f8e02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -132,7 +132,7 @@ abstract class HiveComparisonTest def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false - case PhysicalOperation(_, _, Sort(_, _)) => true + case PhysicalOperation(_, _, Sort(_, true, _)) => true case _ => plan.children.iterator.exists(isSorted) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f57f31af1556..5d0fb7237011 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -32,6 +32,13 @@ case class Nested3(f3: Int) * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") { + checkAnswer( + sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"), + sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq + ) + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect sql( From 19a8802e703e6b075a148ba73dc9dd80748d6322 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 30 Dec 2014 12:16:45 -0800 Subject: [PATCH 022/215] [SPARK-4493][SQL] Tests for IsNull / IsNotNull in the ParquetFilterSuite This is a follow-up of #3367 and #3644. At the time #3644 was written, #3367 hadn't been merged yet, thus `IsNull` and `IsNotNull` filters are not covered in the first version of `ParquetFilterSuite`. This PR adds corresponding test cases. 
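For readers skimming the diff below: the new cases wrap primitive values in `Option` so that schema inference marks the column as nullable, which is what makes `IsNull` / `IsNotNull` meaningful. A minimal illustration of that convention (a sketch, not part of the patch itself):

```scala
// Illustrative sketch only: an Option-wrapped AnyVal is inferred as a nullable column,
// while a bare AnyVal is inferred as non-nullable.
val nonNullableRows = (1 to 4).map(Tuple1.apply)            // _1 inferred as non-nullable
val nullableRows    = (1 to 4).map(i => Tuple1(Option(i)))  // _1 inferred as nullable
```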
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/3748) Author: Cheng Lian Closes #3748 from liancheng/test-null-filters and squashes the following commits: 1ab943f [Cheng Lian] IsNull and IsNotNull Parquet filter test case for boolean type bcd616b [Cheng Lian] Adds Parquet filter pushedown tests for IsNull and IsNotNull --- .../sql/parquet/ParquetFilterSuite.scala | 60 +++++++++++++++---- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index b17300475b6f..4c3a04506ce4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -28,11 +28,14 @@ import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. * - * Notice that `!(a cmp b)` are always transformed to its negated form `a cmp' b` by the - * `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)` - * results a `GtEq` filter predicate rather than a `Not`. + * NOTE: * - * @todo Add test cases for `IsNull` and `IsNotNull` after merging PR #3367 + * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the + * `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)` + * results in a `GtEq` filter predicate rather than a `Not`. + * + * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred + * data type is nullable. */ class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext @@ -85,14 +88,26 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(Tuple1.apply)) { rdd => + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Boolean]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Boolean]]) { + Seq(Row(true), Row(false)) + } + checkFilterPushdown(rdd, '_1)('_1 === true, classOf[Eq[java.lang.Boolean]])(true) - checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]])(false) + checkFilterPushdown(rdd, '_1)('_1 !== true, classOf[Operators.NotEq[java.lang.Boolean]]) { + false + } } } test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(Tuple1.apply)) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[Integer]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[Integer]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[Integer]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[Integer]]) { (2 to 4).map(Row.apply(_)) @@ -118,7 +133,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toLong))) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Long]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Long]]) { + (1 to 4).map(Row.apply(_)) + } + 
checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Long]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Long]]) { (2 to 4).map(Row.apply(_)) @@ -144,7 +164,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toFloat))) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Float]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Float]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Float]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Float]]) { (2 to 4).map(Row.apply(_)) @@ -170,7 +195,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toDouble))) { rdd => + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.Double]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.Double]]) { + (1 to 4).map(Row.apply(_)) + } + checkFilterPushdown(rdd, '_1)('_1 === 1, classOf[Eq[java.lang.Double]])(1) checkFilterPushdown(rdd, '_1)('_1 !== 1, classOf[Operators.NotEq[java.lang.Double]]) { (2 to 4).map(Row.apply(_)) @@ -197,6 +227,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - string") { withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { rdd => + checkFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) + checkFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { + (1 to 4).map(i => Row.apply(i.toString)) + } + checkFilterPushdown(rdd, '_1)('_1 === "1", classOf[Eq[String]])("1") checkFilterPushdown(rdd, '_1)('_1 !== "1", classOf[Operators.NotEq[String]]) { (2 to 4).map(i => Row.apply(i.toString)) @@ -227,6 +262,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { rdd => + checkBinaryFilterPushdown(rdd, '_1)('_1.isNull, classOf[Eq[java.lang.String]])(Seq.empty[Row]) + checkBinaryFilterPushdown(rdd, '_1)('_1.isNotNull, classOf[NotEq[java.lang.String]]) { + (1 to 4).map(i => Row.apply(i.b)).toSeq + } + checkBinaryFilterPushdown(rdd, '_1)('_1 === 1.b, classOf[Eq[Array[Byte]]])(1.b) checkBinaryFilterPushdown(rdd, '_1)('_1 !== 1.b, classOf[Operators.NotEq[Array[Byte]]]) { (2 to 4).map(i => Row.apply(i.b)).toSeq From f7a41a0e79561a722e41800257dca886732ccaad Mon Sep 17 00:00:00 2001 From: luogankun Date: Tue, 30 Dec 2014 12:17:49 -0800 Subject: [PATCH 023/215] [SPARK-4916][SQL][DOCS]Update SQL programming guide about cache section `SchemeRDD.cache()` now uses in-memory columnar storage. 
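As a quick illustration of the behaviour the updated guide describes (a sketch only, assuming an existing `SQLContext` named `sqlContext` with a registered table `logs`):

```scala
// Both paths now cache the data in the in-memory columnar format.
sqlContext.cacheTable("logs")       // cache by table name
sqlContext.table("logs").cache()    // equivalent, as of this change

sqlContext.uncacheTable("logs")     // remove the table from memory
```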
Author: luogankun Closes #3759 from luogankun/SPARK-4916 and squashes the following commits: 7b39864 [luogankun] [SPARK-4916]Update SQL programming guide 6018122 [luogankun] Merge branch 'master' of https://github.com/apache/spark into SPARK-4916 0b93785 [luogankun] [SPARK-4916]Update SQL programming guide 99b2336 [luogankun] [SPARK-4916]Update SQL programming guide --- docs/sql-programming-guide.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2aea8a8aedaf..1b5fde991e40 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -831,13 +831,10 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. -Note that if you call `schemaRDD.cache()` rather than `sqlContext.cacheTable(...)`, tables will _not_ be cached using -the in-memory columnar format, and therefore `sqlContext.cacheTable(...)` is strongly recommended for this use case. - Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running `SET key=value` commands using SQL. From 2deac748b4e1245c2cb9bd43ad87c80d6d130a83 Mon Sep 17 00:00:00 2001 From: luogankun Date: Tue, 30 Dec 2014 12:18:55 -0800 Subject: [PATCH 024/215] [SPARK-4930][SQL][DOCS]Update SQL programming guide, CACHE TABLE is eager `CACHE TABLE tbl` is now __eager__ by default not __lazy__ Author: luogankun Closes #3773 from luogankun/SPARK-4930 and squashes the following commits: cc17b7d [luogankun] [SPARK-4930][SQL][DOCS]Update SQL programming guide, add CACHE [LAZY] TABLE [AS SELECT] ... bffe0e8 [luogankun] [SPARK-4930][SQL][DOCS]Update SQL programming guide, CACHE TABLE tbl is eager --- docs/sql-programming-guide.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 1b5fde991e40..729045b81a8c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1007,12 +1007,11 @@ let user control table caching explicitly: CACHE TABLE logs_last_month; UNCACHE TABLE logs_last_month; -**NOTE:** `CACHE TABLE tbl` is lazy, similar to `.cache` on an RDD. This command only marks `tbl` to ensure that -partitions are cached when calculated but doesn't actually cache it until a query that touches `tbl` is executed. -To force the table to be cached, you may simply count the table immediately after executing `CACHE TABLE`: +**NOTE:** `CACHE TABLE tbl` is now __eager__ by default not __lazy__. Don’t need to trigger cache materialization manually anymore. - CACHE TABLE logs_last_month; - SELECT COUNT(1) FROM logs_last_month; +Spark SQL newly introduced a statement to let user control table caching whether or not lazy since Spark 1.2.0: + + CACHE [LAZY] TABLE [AS SELECT] ... 
Several caching related features are not supported yet: From a75dd83b72586695768c89ed32b240aa8f48f32c Mon Sep 17 00:00:00 2001 From: guowei2 Date: Tue, 30 Dec 2014 12:21:00 -0800 Subject: [PATCH 025/215] [SPARK-4928][SQL] Fix: Operator '>,<,>=,<=' with decimal between different precision report error case operator with decimal between different precision, we need change them to unlimited Author: guowei2 Closes #3767 from guowei2/SPARK-4928 and squashes the following commits: c6a6e3e [guowei2] fix code style 3214e0a [guowei2] add test case b4985a2 [guowei2] fix code style 27adf42 [guowei2] Fix: Operation '>,<,>=,<=' with Decimal report error --- .../catalyst/analysis/HiveTypeCoercion.scala | 16 ++++++++++++++++ .../analysis/DecimalPrecisionSuite.scala | 17 +++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e38114ab3cf2..242f28f67029 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -361,6 +361,22 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + case LessThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThan(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + + case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b: BinaryExpression if b.left.dataType != b.right.dataType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index d5b7d2789a10..3677a6e72e23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -49,6 +49,15 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { assert(analyzer(plan).schema.fields(0).dataType === expectedType) } + private def checkComparison(expression: Expression, expectedType: DataType): Unit = { + val plan = Project(Alias(expression, "c")() :: Nil, relation) + val comparison = analyzer(plan).collect { + case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e + }.head + assert(comparison.left.dataType === expectedType) + assert(comparison.right.dataType === expectedType) + } + test("basic operations") { checkType(Add(d1, d2), DecimalType(6, 2)) checkType(Subtract(d1, d2), DecimalType(6, 2)) @@ -65,6 +74,14 @@ class 
DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) } + test("Comparison operations") { + checkComparison(LessThan(i, d1), DecimalType.Unlimited) + checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) + checkComparison(GreaterThanOrEqual(d1, f), DoubleType) + checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) + } + test("bringing in primitive types") { checkType(Add(d1, i), DecimalType(12, 1)) checkType(Add(d1, f), DoubleType) From 61a99f6a11d85e931e7d60f9ab4370b3b40a52ef Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 30 Dec 2014 13:38:27 -0800 Subject: [PATCH 026/215] [SPARK-4937][SQL] Normalizes conjunctions and disjunctions to eliminate common predicates This PR is a simplified version of several filter optimization rules introduced in #3778 authored by scwf. Newly introduced optimizations include: 1. `a && a` => `a` 2. `a || a` => `a` 3. `(a || b || c || ...) && (a || b || d || ...)` => `a && b && (c || d || ...)` The 3rd rule is particularly useful for optimizing the following query, which is planned into a cartesian product ```sql SELECT * FROM t1, t2 WHERE (t1.key = t2.key AND t1.value > 10) OR (t1.key = t2.key AND t2.value < 20) ``` to the following one, which is planned into an equi-join: ```sql SELECT * FROM t1, t2 WHERE t1.key = t2.key AND (t1.value > 10 OR t2.value < 20) ``` The example above is quite artificial, but common predicates are likely to appear in real life complex queries (like the one mentioned in #3778). A difference between this PR and #3778 is that these optimizations are not limited to `Filter`, but are generalized to all logical plan nodes. Thanks to scwf for bringing up these optimizations, and chenghao-intel for the generalization suggestion. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/3784) Author: Cheng Lian Closes #3784 from liancheng/normalize-filters and squashes the following commits: caca560 [Cheng Lian] Moves filter normalization into BooleanSimplification rule 4ab3a58 [Cheng Lian] Fixes test failure, adds more tests 5d54349 [Cheng Lian] Fixes typo in comment 2abbf8e [Cheng Lian] Forgot our sacred Apache licence header... 
cf95639 [Cheng Lian] Adds an optimization rule for filter normalization --- .../sql/catalyst/expressions/predicates.scala | 9 ++- .../sql/catalyst/optimizer/Optimizer.scala | 27 +++++-- .../optimizer/NormalizeFiltersSuite.scala | 72 +++++++++++++++++++ .../columnar/PartitionBatchPruningSuite.scala | 10 ++- 4 files changed, 110 insertions(+), 8 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 94b6fb084d38..cb5ff6795986 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType @@ -48,6 +47,14 @@ trait PredicateHelper { } } + protected def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = { + condition match { + case Or(cond1, cond2) => + splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2) + case other => other :: Nil + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when is is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f2eae6400d2..cd3137980ca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -294,11 +294,16 @@ object OptimizeIn extends Rule[LogicalPlan] { } /** - * Simplifies boolean expressions where the answer can be determined without evaluating both sides. + * Simplifies boolean expressions: + * + * 1. Simplifies expressions whose answer can be determined without evaluating both sides. + * 2. Eliminates / extracts common factors. + * 3. Removes `Not` operator. + * * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus * is only safe when evaluations of expressions does not result in side effects. */ -object BooleanSimplification extends Rule[LogicalPlan] { +object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case and @ And(left, right) => @@ -307,7 +312,9 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (l, Literal(true, BooleanType)) => l case (Literal(false, BooleanType), _) => Literal(false) case (_, Literal(false, BooleanType)) => Literal(false) - case (_, _) => and + // a && a && a ... 
=> a + case _ if splitConjunctivePredicates(and).distinct.size == 1 => left + case _ => and } case or @ Or(left, right) => @@ -316,7 +323,19 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, Literal(true, BooleanType)) => Literal(true) case (Literal(false, BooleanType), r) => r case (l, Literal(false, BooleanType)) => l - case (_, _) => or + // a || a || a ... => a + case _ if splitDisjunctivePredicates(or).distinct.size == 1 => left + // (a && b && c && ...) || (a && b && d && ...) => a && b && (c || d || ...) + case _ => + val lhsSet = splitConjunctivePredicates(left).toSet + val rhsSet = splitConjunctivePredicates(right).toSet + val common = lhsSet.intersect(rhsSet) + + (lhsSet.diff(common).reduceOption(And) ++ rhsSet.diff(common).reduceOption(And)) + .reduceOption(Or) + .map(_ :: common.toList) + .getOrElse(common.toList) + .reduce(And) } case not @ Not(exp) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala new file mode 100644 index 000000000000..906300d8336c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFiltersSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class NormalizeFiltersSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq( + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators), + Batch("NormalizeFilters", FixedPoint(100), + BooleanSimplification, + SimplifyFilters)) + } + + val relation = LocalRelation('a.int, 'b.int, 'c.string) + + def checkExpression(original: Expression, expected: Expression): Unit = { + val actual = Optimize(relation.where(original)).collect { case f: Filter => f.condition }.head + val result = (actual, expected) match { + case (And(l1, r1), And(l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1) + case (Or (l1, r1), Or (l2, r2)) => (l1 == l2 && r1 == r2) || (l1 == r2 && l2 == r1) + case (lhs, rhs) => lhs fastEquals rhs + } + + assert(result, s"$actual isn't equivalent to $expected") + } + + test("a && a => a") { + checkExpression('a === 1 && 'a === 1, 'a === 1) + checkExpression('a === 1 && 'a === 1 && 'a === 1, 'a === 1) + } + + test("a || a => a") { + checkExpression('a === 1 || 'a === 1, 'a === 1) + checkExpression('a === 1 || 'a === 1 || 'a === 1, 'a === 1) + } + + test("(a && b) || (a && c) => a && (b || c)") { + checkExpression( + ('a === 1 && 'a < 10) || ('a > 2 && 'a === 1), + ('a === 1) && ('a < 10 || 'a > 2)) + + checkExpression( + ('a < 1 && 'b > 2 && 'c.isNull) || ('a < 1 && 'c === "hello" && 'b > 2), + ('c.isNull || 'c === "hello") && 'a < 1 && 'b > 2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 82afa31a99a7..1915c25392f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -105,7 +105,9 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be test(query) { val schemaRdd = sql(query) - assertResult(expectedQueryResult.toArray, "Wrong query result") { + val queryExecution = schemaRdd.queryExecution + + assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { schemaRdd.collect().map(_.head).toArray } @@ -113,8 +115,10 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head - assert(readBatches === expectedReadBatches, "Wrong number of read batches") - assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + assert(readBatches === expectedReadBatches, s"Wrong number of read batches: $queryExecution") + assert( + readPartitions === expectedReadPartitions, + s"Wrong number of read partitions: $queryExecution") } } } From 7425bec320227bf8818dc2844c12d5373d166364 Mon Sep 17 00:00:00 2001 From: Michael Davies Date: Tue, 30 Dec 2014 13:40:51 -0800 Subject: [PATCH 027/215] [SPARK-4386] Improve performance when writing 
Parquet files Convert type of RowWriteSupport.attributes to Array. Analysis of performance for writing very wide tables shows that time is spent predominantly in apply method on attributes var. Type of attributes previously was LinearSeqOptimized and apply is O(N) which made write O(N squared). Measurements on 575 column table showed this change made a 6x improvement in write times. Author: Michael Davies Closes #3843 from MickDavies/SPARK-4386 and squashes the following commits: 892519d [Michael Davies] [SPARK-4386] Improve performance when writing Parquet files --- .../org/apache/spark/sql/parquet/ParquetTableSupport.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index ef3687e69296..9049eb5932b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -130,7 +130,7 @@ private[parquet] object RowReadSupport { private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] var writer: RecordConsumer = null - private[parquet] var attributes: Seq[Attribute] = null + private[parquet] var attributes: Array[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) @@ -138,7 +138,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) if (attributes == null) { - attributes = ParquetTypesConverter.convertFromString(origAttributesStr) + attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray } log.debug(s"write support initialized for requested schema $attributes") From 8f29b7cafc2b6e802e4eb21f681d6369da2f30fa Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 30 Dec 2014 13:44:30 -0800 Subject: [PATCH 028/215] [SPARK-4935][SQL] When hive.cli.print.header configured, spark-sql aborted if passed in a invalid sql If we passed in a wrong sql like ```abdcdfsfs```, the spark-sql script aborted. 
Author: wangfei Author: Fei Wang Closes #3761 from scwf/patch-10 and squashes the following commits: 46dc344 [Fei Wang] revert console.printError(rc.getErrorMessage()) 0330e07 [wangfei] avoid to print error message repeatedly 1614a11 [wangfei] spark-sql abort when passed in a wrong sql --- .../spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 6ed8fd2768f9..7a3d76c61c3a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -60,7 +60,7 @@ private[hive] abstract class AbstractSparkSQLDriver( } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(0, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) } } From 07fa1910d9c4092d670381c447403105f01c584e Mon Sep 17 00:00:00 2001 From: wangxiaojing Date: Tue, 30 Dec 2014 13:54:12 -0800 Subject: [PATCH 029/215] [SPARK-4570][SQL]add BroadcastLeftSemiJoinHash JIRA issue: [SPARK-4570](https://issues.apache.org/jira/browse/SPARK-4570) We are planning to create a `BroadcastLeftSemiJoinHash` to implement the broadcast join for `left semijoin` In left semijoin : If the size of data from right side is smaller than the user-settable threshold `AUTO_BROADCASTJOIN_THRESHOLD`, the planner would mark it as the `broadcast` relation and mark the other relation as the stream side. The broadcast table will be broadcasted to all of the executors involved in the join, as a `org.apache.spark.broadcast.Broadcast` object. It will use `joins.BroadcastLeftSemiJoinHash`.,else it will use `joins.LeftSemiJoinHash`. The benchmark suggests these made the optimized version 4x faster when `left semijoin`

Original:
left semi join : 9288 ms
Optimized:
left semi join : 1963 ms
The micro benchmark loads `data1/kv3.txt` into a normal Hive table. Benchmark code:

 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.hive.HiveContext

 // Returns the wall-clock time, in milliseconds, taken to evaluate `f`.
 def benchmark(f: => Unit) = {
    val begin = System.currentTimeMillis()
    f
    val end = System.currentTimeMillis()
    end - begin
  }
  val sc = new SparkContext(
    new SparkConf()
      .setMaster("local")
      .setAppName(getClass.getSimpleName.stripSuffix("$")))
  val hiveContext = new HiveContext(sc)
  import hiveContext._
  sql("drop table if exists left_table")
  sql("drop table if exists right_table")
  sql( """create table left_table (key int, value string)
       """.stripMargin)
  sql( s"""load data local inpath "/data1/kv3.txt" into table left_table""")
  sql( """create table right_table (key int, value string)
       """.stripMargin)
  sql(
    """
      |from left_table
      |insert overwrite table right_table
      |select left_table.key, left_table.value
    """.stripMargin)

  val leftSemiJoin = sql(
    """select a.key from left_table a
      |left semi join right_table b on a.key = b.key""".stripMargin)
  val leftSemiJoinDuration = benchmark(leftSemiJoin.count())
  println(s"left semi join : $leftSemiJoinDuration ms ")
Author: wangxiaojing Closes #3442 from wangxiaojing/SPARK-4570 and squashes the following commits: a4a43c9 [wangxiaojing] rebase f103983 [wangxiaojing] change style fbe4887 [wangxiaojing] change style ff2e618 [wangxiaojing] add testsuite 1a8da2a [wangxiaojing] add BroadcastLeftSemiJoinHash --- .../spark/sql/execution/SparkStrategies.scala | 6 ++ .../joins/BroadcastLeftSemiJoinHash.scala | 67 +++++++++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 38 +++++++++++ .../spark/sql/hive/StatisticsSuite.scala | 50 +++++++++++++- 4 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9151da69ed44..ce878c137e62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + val semiJoin = joins.BroadcastLeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => val semiJoin = joins.LeftSemiJoinHash( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala new file mode 100644 index 000000000000..2ab064fd0151 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class BroadcastLeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def output = left.output + + override def execute() = { + val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val broadcastedRelation = sparkContext.broadcast(hashSet) + + streamedPlan.execute().mapPartitions { streamIter => + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 0378fd7e367f..1a4232dab86e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j + case j: BroadcastLeftSemiJoinHash => j } assert(operators.size === 1) @@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { """.stripMargin), (null, 10) :: Nil) } + + test("broadcasted left semi join operator selection") { + clearCache() + sql("CACHE TABLE testData") + val tmp = autoBroadcastJoinThreshold + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } + + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + sql("UNCACHE TABLE testData") + } + + test("left semi join") { + val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + checkAnswer(rdd, + (1, 1) :: + (1, 2) :: + (2, 1) :: + (2, 2) :: + (3, 1) :: + (3, 2) :: Nil) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ff4071d8e2f1..4b6a9308b981 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import 
scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.execution._ @@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ) } + test("auto converts to broadcast left semi join, by size estimate of a relation") { + val leftSemiJoinQuery = + """SELECT * FROM src a + |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin + val answer = (86, "val_86") :: Nil + + var rdd = sql(leftSemiJoinQuery) + + // Assert src has a size smaller than the threshold. + val sizes = rdd.queryExecution.analyzed.collect { + case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass + .isAssignableFrom(r.getClass) => + r.statistics.sizeInBytes + } + assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold + && sizes(0) <= autoBroadcastJoinThreshold, + s"query should contain two relations, each of which has size smaller than autoConvertSize") + + // Using `sparkPlan` because for relevant patterns in HashJoin to be + // matched, other strategies need to be applied. + var bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.size === 1, + s"actual query plans do not contain broadcast join: ${rdd.queryExecution}") + + checkAnswer(rdd, answer) // check correctness of output + + TestHive.settings.synchronized { + val tmp = autoBroadcastJoinThreshold + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + rdd = sql(leftSemiJoinQuery) + bhj = rdd.queryExecution.sparkPlan.collect { + case j: BroadcastLeftSemiJoinHash => j + } + assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") + + val shj = rdd.queryExecution.sparkPlan.collect { + case j: LeftSemiJoinHash => j + } + assert(shj.size === 1, + "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") + + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + } + + } } From b239ea1c31aeaa752d5dc8f45423df1f8c0924ca Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 30 Dec 2014 14:00:57 -0800 Subject: [PATCH 030/215] SPARK-3955 part 2 [CORE] [HOTFIX] Different versions between jackson-mapper-asl and jackson-core-asl pwendell https://github.com/apache/spark/commit/2483c1efb6429a7d8a20c96d18ce2fec93a1aff9 didn't actually add a reference to `jackson-core-asl` as intended, but a second redundant reference to `jackson-mapper-asl`, as markhamstra picked up on (https://github.com/apache/spark/pull/3716#issuecomment-68180192) This just rectifies the typo. I missed it as well; the original PR https://github.com/apache/spark/pull/2818 had it correct and I also didn't see the problem. 
Author: Sean Owen Closes #3829 from srowen/SPARK-3955 and squashes the following commits: 6cfdc4e [Sean Owen] Actually refer to jackson-core-asl --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index a843af2b22d6..05f59a9b4140 100644 --- a/pom.xml +++ b/pom.xml @@ -827,7 +827,7 @@ org.codehaus.jackson - jackson-mapper-asl + jackson-core-asl ${jackson.version} From 0f31992c61f6662e5347745f6a1ac272a5fd63c9 Mon Sep 17 00:00:00 2001 From: Jakub Dubovsky Date: Tue, 30 Dec 2014 14:19:07 -0800 Subject: [PATCH 031/215] [Spark-4995] Replace Vector.toBreeze.activeIterator with foreachActive New foreachActive method of vector was introduced by SPARK-4431 as more efficient alternative to vector.toBreeze.activeIterator. There are some parts of codebase where it was not yet replaced. dbtsai Author: Jakub Dubovsky Closes #3846 from james64/SPARK-4995-foreachActive and squashes the following commits: 3eb7e37 [Jakub Dubovsky] Scalastyle fix 32fe6c6 [Jakub Dubovsky] activeIterator removed - IndexedRowMatrix.toBreeze 47a4777 [Jakub Dubovsky] activeIterator removed in RowMatrix.toBreeze 90a7d98 [Jakub Dubovsky] activeIterator removed in MLUtils.saveAsLibSVMFile --- .../spark/mllib/linalg/distributed/IndexedRowMatrix.scala | 2 +- .../apache/spark/mllib/linalg/distributed/RowMatrix.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/util/MLUtils.scala | 8 +++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 5c1acca0ec53..36d8cadd2bdd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -142,7 +142,7 @@ class IndexedRowMatrix( val mat = BDM.zeros[Double](m, n) rows.collect().foreach { case IndexedRow(rowIndex, vector) => val i = rowIndex.toInt - vector.toBreeze.activeIterator.foreach { case (j, v) => + vector.foreachActive { case (j, v) => mat(i, j) = v } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 10a515af8880..a3fca53929ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -588,8 +588,8 @@ class RowMatrix( val n = numCols().toInt val mat = BDM.zeros[Double](m, n) var i = 0 - rows.collect().foreach { v => - v.toBreeze.activeIterator.foreach { case (j, v) => + rows.collect().foreach { vector => + vector.foreachActive { case (j, v) => mat(i, j) = v } i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 1d07b5dab826..da0da0a168c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -154,10 +154,12 @@ object MLUtils { def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { // TODO: allow to specify label precision and feature precision. 
val dataStr = data.map { case LabeledPoint(label, features) => - val featureStrings = features.toBreeze.activeIterator.map { case (i, v) => - s"${i + 1}:$v" + val sb = new StringBuilder(label.toString) + features.foreachActive { case (i, v) => + sb += ' ' + sb ++= s"${i + 1}:$v" } - (Iterator(label) ++ featureStrings).mkString(" ") + sb.mkString } dataStr.saveAsTextFile(dir) } From 6a897829444e2ef273586511f93a40d36e64fb0b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Dec 2014 14:39:13 -0800 Subject: [PATCH 032/215] [SPARK-4813][Streaming] Fix the issue that ContextWaiter didn't handle 'spurious wakeup' Used `Condition` to rewrite `ContextWaiter` because it provides a convenient API `awaitNanos` for timeout. Author: zsxwing Closes #3661 from zsxwing/SPARK-4813 and squashes the following commits: 52247f5 [zsxwing] Add explicit unit type be42bcf [zsxwing] Update as per review suggestion e06bd4f [zsxwing] Fix the issue that ContextWaiter didn't handle 'spurious wakeup' --- .../spark/streaming/ContextWaiter.scala | 63 ++++++++++++++----- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala index a0aeacbc733b..fdbbe2aa6ef0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -17,30 +17,63 @@ package org.apache.spark.streaming +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + private[streaming] class ContextWaiter { + + private val lock = new ReentrantLock() + private val condition = lock.newCondition() + + // Guarded by "lock" private var error: Throwable = null - private var stopped: Boolean = false - def notifyError(e: Throwable) = synchronized { - error = e - notifyAll() - } + // Guarded by "lock" + private var stopped: Boolean = false - def notifyStop() = synchronized { - stopped = true - notifyAll() + def notifyError(e: Throwable): Unit = { + lock.lock() + try { + error = e + condition.signalAll() + } finally { + lock.unlock() + } } - def waitForStopOrError(timeout: Long = -1) = synchronized { - // If already had error, then throw it - if (error != null) { - throw error + def notifyStop(): Unit = { + lock.lock() + try { + stopped = true + condition.signalAll() + } finally { + lock.unlock() } + } - // If not already stopped, then wait - if (!stopped) { - if (timeout < 0) wait() else wait(timeout) + /** + * Return `true` if it's stopped; or throw the reported error if `notifyError` has been called; or + * `false` if the waiting time detectably elapsed before return from the method. + */ + def waitForStopOrError(timeout: Long = -1): Boolean = { + lock.lock() + try { + if (timeout < 0) { + while (!stopped && error == null) { + condition.await() + } + } else { + var nanos = TimeUnit.MILLISECONDS.toNanos(timeout) + while (!stopped && error == null && nanos > 0) { + nanos = condition.awaitNanos(nanos) + } + } + // If already had error, then throw it if (error != null) throw error + // already stopped or timeout + stopped + } finally { + lock.unlock() } } } From 035bac88c732247c79a1bbad4f9191090cbbdc9a Mon Sep 17 00:00:00 2001 From: Liu Jiongzhou Date: Tue, 30 Dec 2014 15:55:56 -0800 Subject: [PATCH 033/215] [SPARK-4998][MLlib]delete the "train" function To make the functions with the same in "object" effective, specially when using java reflection. 
As the "train" function defined in "class DecisionTree" will hide the functions with the same name in "object DecisionTree". JIRA[SPARK-4998] Author: Liu Jiongzhou Closes #3836 from ljzzju/master and squashes the following commits: 4e13133 [Liu Jiongzhou] [MLlib]delete the "train" function --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 73e7e32c6db3..b3e8ed9af8c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -64,13 +64,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val rfModel = rf.run(input) rfModel.trees(0) } - - /** - * Trains a decision tree model over an RDD. This is deprecated because it hides the static - * methods with the same name in Java. - */ - @deprecated("Please use DecisionTree.run instead.", "1.2.0") - def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input) } object DecisionTree extends Serializable with Logging { From 352ed6bbe3c3b67e52e298e7c535ae414d96beca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 30 Dec 2014 18:12:20 -0800 Subject: [PATCH 034/215] [SPARK-1010] Clean up uses of System.setProperty in unit tests Several of our tests call System.setProperty (or test code which implicitly sets system properties) and don't always reset/clear the modified properties, which can create ordering dependencies between tests and cause hard-to-diagnose failures. This patch removes most uses of System.setProperty from our tests, since in most cases we can use SparkConf to set these configurations (there are a few exceptions, including the tests of SparkConf itself). For the cases where we continue to use System.setProperty, this patch introduces a `ResetSystemProperties` ScalaTest mixin class which snapshots the system properties before individual tests and to automatically restores them on test completion / failure. See the block comment at the top of the ResetSystemProperties class for more details. Author: Josh Rosen Closes #3739 from JoshRosen/cleanup-system-properties-in-tests and squashes the following commits: 0236d66 [Josh Rosen] Replace setProperty uses in two example programs / tools 3888fe3 [Josh Rosen] Remove setProperty use in LocalJavaStreamingContext 4f4031d [Josh Rosen] Add note on why SparkSubmitSuite needs ResetSystemProperties 4742a5b [Josh Rosen] Clarify ResetSystemProperties trait inheritance ordering. 0eaf0b6 [Josh Rosen] Remove setProperty call in TaskResultGetterSuite. 7a3d224 [Josh Rosen] Fix trait ordering 3fdb554 [Josh Rosen] Remove setProperty call in TaskSchedulerImplSuite bee20df [Josh Rosen] Remove setProperty calls in SparkContextSchedulerCreationSuite 655587c [Josh Rosen] Remove setProperty calls in JobCancellationSuite 3f2f955 [Josh Rosen] Remove System.setProperty calls in DistributedSuite cfe9cce [Josh Rosen] Remove use of system properties in SparkContextSuite 8783ab0 [Josh Rosen] Remove TestUtils.setSystemProperty, since it is subsumed by the ResetSystemProperties trait. 
633a84a [Josh Rosen] Remove use of system properties in FileServerSuite 25bfce2 [Josh Rosen] Use ResetSystemProperties in UtilsSuite 1d1aa5a [Josh Rosen] Use ResetSystemProperties in SizeEstimatorSuite dd9492b [Josh Rosen] Use ResetSystemProperties in AkkaUtilsSuite b0daff2 [Josh Rosen] Use ResetSystemProperties in BlockManagerSuite e9ded62 [Josh Rosen] Use ResetSystemProperties in TaskSchedulerImplSuite 5b3cb54 [Josh Rosen] Use ResetSystemProperties in SparkListenerSuite 0995c4b [Josh Rosen] Use ResetSystemProperties in SparkContextSchedulerCreationSuite c83ded8 [Josh Rosen] Use ResetSystemProperties in SparkConfSuite 51aa870 [Josh Rosen] Use withSystemProperty in ShuffleSuite 60a63a1 [Josh Rosen] Use ResetSystemProperties in JobCancellationSuite 14a92e4 [Josh Rosen] Use withSystemProperty in FileServerSuite 628f46c [Josh Rosen] Use ResetSystemProperties in DistributedSuite 9e3e0dd [Josh Rosen] Add ResetSystemProperties test fixture mixin; use it in SparkSubmitSuite. 4dcea38 [Josh Rosen] Move withSystemProperty to TestUtils class. --- .../org/apache/spark/DistributedSuite.scala | 21 ++----- .../org/apache/spark/FileServerSuite.scala | 16 ++--- .../apache/spark/JobCancellationSuite.scala | 21 +++---- .../scala/org/apache/spark/ShuffleSuite.scala | 22 +++---- .../org/apache/spark/SparkConfSuite.scala | 51 ++++++--------- .../SparkContextSchedulerCreationSuite.scala | 31 ++++------ .../org/apache/spark/SparkContextSuite.scala | 62 +++++++------------ .../spark/deploy/SparkSubmitSuite.scala | 6 +- .../spark/scheduler/SparkListenerSuite.scala | 9 +-- .../scheduler/TaskResultGetterSuite.scala | 23 +++---- .../scheduler/TaskSchedulerImplSuite.scala | 6 +- .../spark/storage/BlockManagerSuite.scala | 23 +++---- .../apache/spark/util/AkkaUtilsSuite.scala | 2 +- .../spark/util/ResetSystemProperties.scala | 57 +++++++++++++++++ .../spark/util/SizeEstimatorSuite.scala | 38 +++--------- .../org/apache/spark/util/UtilsSuite.scala | 2 +- .../apache/spark/examples/BroadcastTest.scala | 6 +- .../streaming/LocalJavaStreamingContext.java | 8 ++- .../streaming/LocalJavaStreamingContext.java | 8 ++- .../streaming/LocalJavaStreamingContext.java | 8 ++- .../streaming/LocalJavaStreamingContext.java | 8 ++- .../streaming/LocalJavaStreamingContext.java | 8 ++- .../spark/tools/StoragePerfTester.scala | 12 ++-- 23 files changed, 216 insertions(+), 232 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 998f3008ec0e..97ea3578aa8b 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers @@ -29,16 +28,10 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter - with LocalSparkContext { +class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" - after { - System.clearProperty("spark.reducer.maxMbInFlight") - System.clearProperty("spark.storage.memoryFraction") - } - test("task throws not serializable exception") { // Ensures that executors do not crash when an exn is not 
serializable. If executors crash, // this test will hang. Correct behavior is that executors don't crash but fail tasks @@ -84,15 +77,14 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter } test("groupByKey where map output sizes exceed maxMbInFlight") { - System.setProperty("spark.reducer.maxMbInFlight", "1") - sc = new SparkContext(clusterUrl, "test") + val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000))) val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect() assert(groups.length === 16) assert(groups.map(_._2).sum === 2000) - // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block } test("accumulators") { @@ -210,7 +202,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter } test("compute without caching when no partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.0001") sc = new SparkContext(clusterUrl, "test") // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory @@ -218,12 +209,11 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter assert(data.count() === 4000000) assert(data.count() === 4000000) assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") } test("compute when only some partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.01") - sc = new SparkContext(clusterUrl, "test") + val conf = new SparkConf().set("spark.storage.memoryFraction", "0.01") + sc = new SparkContext(clusterUrl, "test", conf) // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions // to make sure that *some* of them do fit though @@ -231,7 +221,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter assert(data.count() === 4000000) assert(data.count() === 4000000) assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") } test("passing environment variables to cluster") { diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 49426545c767..0f49ce4754fb 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -31,10 +31,11 @@ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ + def newConf: SparkConf = new SparkConf(loadDefaults = false).set("spark.authenticate", "false") + override def beforeEach() { super.beforeEach() resetSparkContext() - System.setProperty("spark.authenticate", "false") } override def beforeAll() { @@ -52,7 +53,6 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarFile = new File(testTempDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) - System.setProperty("spark.authenticate", "false") val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) @@ -74,7 
+74,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test("Distributing files locally") { - sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { @@ -108,7 +108,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { test("Distributing files locally using URL as input") { // addFile("file:///....") - sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test", newConf) sc.addFile(new File(tmpFile.toString).toURI.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { @@ -122,7 +122,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test ("Dynamically adding JARS locally") { - sc = new SparkContext("local[4]", "test") + sc = new SparkContext("local[4]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => @@ -133,7 +133,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { @@ -147,7 +147,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1,1)) sc.parallelize(testData).foreach { x => @@ -158,7 +158,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,512]", "test") + sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) val testData = Array((1,1)) sc.parallelize(testData).foreach { x => diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 41ed2bce55ce..7584ae79fc92 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -40,12 +40,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter override def afterEach() { super.afterEach() resetSparkContext() - System.clearProperty("spark.scheduler.mode") } test("local mode, FIFO scheduler") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local[2]", "test") + val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local[2]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. 
@@ -53,10 +52,10 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter } test("local mode, fair scheduler") { - System.setProperty("spark.scheduler.mode", "FAIR") + val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local[2]", "test") + conf.set("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local[2]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -64,8 +63,8 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter } test("cluster mode, FIFO scheduler") { - System.setProperty("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local-cluster[2,1,512]", "test") + val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -73,10 +72,10 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter } test("cluster mode, fair scheduler") { - System.setProperty("spark.scheduler.mode", "FAIR") + val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local-cluster[2,1,512]", "test") + conf.set("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 58a96245a9b5..f57921b76831 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -35,19 +35,15 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex conf.set("spark.test.noStageRetry", "true") test("groupByKey without compression") { - try { - System.setProperty("spark.shuffle.compress", "false") - sc = new SparkContext("local", "test", conf) - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) - val groups = pairs.groupByKey(4).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } finally { - System.setProperty("spark.shuffle.compress", "true") - } + val myConf = conf.clone().set("spark.shuffle.compress", "false") + sc = new SparkContext("local", "test", myConf) + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) + val groups = pairs.groupByKey(4).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) } test("shuffle non-zero block size") { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 5d018ea9868a..790976a5ac30 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,27 +19,20 @@ package org.apache.spark import 
org.scalatest.FunSuite import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} +import org.apache.spark.util.ResetSystemProperties import com.esotericsoftware.kryo.Kryo -class SparkConfSuite extends FunSuite with LocalSparkContext { +class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { test("loading from system properties") { - try { - System.setProperty("spark.test.testProperty", "2") - val conf = new SparkConf() - assert(conf.get("spark.test.testProperty") === "2") - } finally { - System.clearProperty("spark.test.testProperty") - } + System.setProperty("spark.test.testProperty", "2") + val conf = new SparkConf() + assert(conf.get("spark.test.testProperty") === "2") } test("initializing without loading defaults") { - try { - System.setProperty("spark.test.testProperty", "2") - val conf = new SparkConf(false) - assert(!conf.contains("spark.test.testProperty")) - } finally { - System.clearProperty("spark.test.testProperty") - } + System.setProperty("spark.test.testProperty", "2") + val conf = new SparkConf(false) + assert(!conf.contains("spark.test.testProperty")) } test("named set methods") { @@ -117,23 +110,17 @@ class SparkConfSuite extends FunSuite with LocalSparkContext { test("nested property names") { // This wasn't supported by some external conf parsing libraries - try { - System.setProperty("spark.test.a", "a") - System.setProperty("spark.test.a.b", "a.b") - System.setProperty("spark.test.a.b.c", "a.b.c") - val conf = new SparkConf() - assert(conf.get("spark.test.a") === "a") - assert(conf.get("spark.test.a.b") === "a.b") - assert(conf.get("spark.test.a.b.c") === "a.b.c") - conf.set("spark.test.a.b", "A.B") - assert(conf.get("spark.test.a") === "a") - assert(conf.get("spark.test.a.b") === "A.B") - assert(conf.get("spark.test.a.b.c") === "a.b.c") - } finally { - System.clearProperty("spark.test.a") - System.clearProperty("spark.test.a.b") - System.clearProperty("spark.test.a.b.c") - } + System.setProperty("spark.test.a", "a") + System.setProperty("spark.test.a.b", "a.b") + System.setProperty("spark.test.a.b.c", "a.b.c") + val conf = new SparkConf() + assert(conf.get("spark.test.a") === "a") + assert(conf.get("spark.test.a.b") === "a.b") + assert(conf.get("spark.test.a.b.c") === "a.b.c") + conf.set("spark.test.a.b", "A.B") + assert(conf.get("spark.test.a") === "a") + assert(conf.get("spark.test.a.b") === "A.B") + assert(conf.get("spark.test.a.b.c") === "a.b.c") } test("register kryo classes through registerKryoClasses") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 0390a2e4f1db..8ae4f243ec1a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -27,10 +27,13 @@ import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging { - def createTaskScheduler(master: String): TaskSchedulerImpl = { + def createTaskScheduler(master: String): TaskSchedulerImpl = + createTaskScheduler(master, new SparkConf()) + + def createTaskScheduler(master: String, conf: SparkConf): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. 
- sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val createTaskSchedulerMethod = PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]]('createTaskScheduler) val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) @@ -102,19 +105,13 @@ class SparkContextSchedulerCreationSuite } test("local-default-parallelism") { - val defaultParallelism = System.getProperty("spark.default.parallelism") - System.setProperty("spark.default.parallelism", "16") - val sched = createTaskScheduler("local") + val conf = new SparkConf().set("spark.default.parallelism", "16") + val sched = createTaskScheduler("local", conf) sched.backend match { case s: LocalBackend => assert(s.defaultParallelism() === 16) case _ => fail() } - - Option(defaultParallelism) match { - case Some(v) => System.setProperty("spark.default.parallelism", v) - case _ => System.clearProperty("spark.default.parallelism") - } } test("simr") { @@ -155,9 +152,10 @@ class SparkContextSchedulerCreationSuite testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") } - def testMesos(master: String, expectedClass: Class[_]) { + def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { + val conf = new SparkConf().set("spark.mesos.coarse", coarse.toString) try { - val sched = createTaskScheduler(master) + val sched = createTaskScheduler(master, conf) assert(sched.backend.getClass === expectedClass) } catch { case e: UnsatisfiedLinkError => @@ -168,17 +166,14 @@ class SparkContextSchedulerCreationSuite } test("mesos fine-grained") { - System.setProperty("spark.mesos.coarse", "false") - testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend]) + testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend], coarse = false) } test("mesos coarse-grained") { - System.setProperty("spark.mesos.coarse", "true") - testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend]) + testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend], coarse = true) } test("mesos with zookeeper") { - System.setProperty("spark.mesos.coarse", "false") - testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend]) + testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 136202210419..8b3c6871a7b3 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -23,55 +23,37 @@ import org.apache.hadoop.io.BytesWritable class SparkContextSuite extends FunSuite with LocalSparkContext { - /** Allows system properties to be changed in tests */ - private def withSystemProperty[T](property: String, value: String)(block: => T): T = { - val originalValue = System.getProperty(property) - try { - System.setProperty(property, value) - block - } finally { - if (originalValue == null) { - System.clearProperty(property) - } else { - System.setProperty(property, originalValue) - } - } - } - test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 - withSystemProperty("spark.driver.allowMultipleContexts", "false") { - val conf = new SparkConf().setAppName("test").setMaster("local") - sc = new SparkContext(conf) - // A SparkContext is already running, so we shouldn't be able to create a second one - intercept[SparkException] { new 
SparkContext(conf) } - // After stopping the running context, we should be able to create a new one - resetSparkContext() - sc = new SparkContext(conf) - } + val conf = new SparkConf().setAppName("test").setMaster("local") + .set("spark.driver.allowMultipleContexts", "false") + sc = new SparkContext(conf) + // A SparkContext is already running, so we shouldn't be able to create a second one + intercept[SparkException] { new SparkContext(conf) } + // After stopping the running context, we should be able to create a new one + resetSparkContext() + sc = new SparkContext(conf) } test("Can still construct a new SparkContext after failing to construct a previous one") { - withSystemProperty("spark.driver.allowMultipleContexts", "false") { - // This is an invalid configuration (no app name or master URL) - intercept[SparkException] { - new SparkContext(new SparkConf()) - } - // Even though those earlier calls failed, we should still be able to create a new context - sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test")) + val conf = new SparkConf().set("spark.driver.allowMultipleContexts", "false") + // This is an invalid configuration (no app name or master URL) + intercept[SparkException] { + new SparkContext(conf) } + // Even though those earlier calls failed, we should still be able to create a new context + sc = new SparkContext(conf.setMaster("local").setAppName("test")) } test("Check for multiple SparkContexts can be disabled via undocumented debug option") { - withSystemProperty("spark.driver.allowMultipleContexts", "true") { - var secondSparkContext: SparkContext = null - try { - val conf = new SparkConf().setAppName("test").setMaster("local") - sc = new SparkContext(conf) - secondSparkContext = new SparkContext(conf) - } finally { - Option(secondSparkContext).foreach(_.stop()) - } + var secondSparkContext: SparkContext = null + try { + val conf = new SparkConf().setAppName("test").setMaster("local") + .set("spark.driver.allowMultipleContexts", "true") + sc = new SparkContext(conf) + secondSparkContext = new SparkContext(conf) + } finally { + Option(secondSparkContext).foreach(_.stop()) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index eb7bd7ab3986..5eda2d41f0e6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,11 +23,13 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ResetSystemProperties, Utils} import org.scalatest.FunSuite import org.scalatest.Matchers -class SparkSubmitSuite extends FunSuite with Matchers { +// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch +// of properties that neeed to be cleared after tests. 
+class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties { def beforeAll() { System.setProperty("spark.testing", "true") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index b276343cb412..24f41bf8cccd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -26,9 +26,10 @@ import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.ResetSystemProperties -class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers - with BeforeAndAfter with BeforeAndAfterAll { +class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter + with BeforeAndAfterAll with ResetSystemProperties { /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 @@ -37,10 +38,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers sc = new SparkContext("local", "SparkListenerSuite") } - override def afterAll() { - System.clearProperty("spark.akka.frameSize") - } - test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 5768a3a733f0..3aab5a156ee7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} import org.apache.spark.storage.TaskResultBlockId /** @@ -55,27 +55,20 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule /** * Tests related to handling task results (both direct and indirect). */ -class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll - with LocalSparkContext { +class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - override def beforeAll { - // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small - // as we can make it) so the tests don't take too long. - System.setProperty("spark.akka.frameSize", "1") - } - - override def afterAll { - System.clearProperty("spark.akka.frameSize") - } + // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small + // as we can make it) so the tests don't take too long. 
+ def conf: SparkConf = new SparkConf().set("spark.akka.frameSize", "1") test("handling results smaller than Akka frame size") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) assert(result === 2) } test("handling results larger than Akka frame size") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) @@ -89,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA test("task retried if result missing from block manager") { // Set the maximum number of task failures to > 0, so that the task set isn't aborted // after the result is missing. - sc = new SparkContext("local[1,2]", "test") + sc = new SparkContext("local[1,2]", "test", conf) // If this test hangs, it's probably because no resource offers were made after the task // failed. val scheduler: TaskSchedulerImpl = sc.taskScheduler match { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 7532da88c606..40aaf9dd1f1e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -162,12 +162,12 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin } test("Fair Scheduler Test") { - sc = new SparkContext("local", "TaskSchedulerImplSuite") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + val conf = new SparkConf().set("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local", "TaskSchedulerImplSuite", conf) val taskScheduler = new TaskSchedulerImpl(sc) val taskSet = FakeTask.createTaskSet(1) - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.scheduler.allocation.file", xmlPath) val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) schedulableBuilder.buildPools() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 5554efbcbadf..ffe6f039145e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,7 +33,7 @@ import akka.util.Timeout import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -44,18 +44,17 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat -import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} +import org.apache.spark.util._ -class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter - with PrivateMethodTester { +class BlockManagerSuite extends FunSuite with Matchers 
with BeforeAndAfterEach + with PrivateMethodTester with ResetSystemProperties { private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null var actorSystem: ActorSystem = null var master: BlockManagerMaster = null - var oldArch: String = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -79,13 +78,13 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter manager } - before { + override def beforeEach(): Unit = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( "test", "localhost", 0, conf = conf, securityManager = securityMgr) this.actorSystem = actorSystem // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") + System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") conf.set("spark.driver.port", boundPort.toString) @@ -100,7 +99,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter SizeEstimator invokePrivate initialize() } - after { + override def afterEach(): Unit = { if (store != null) { store.stop() store = null @@ -113,14 +112,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter actorSystem.awaitTermination() actorSystem = null master = null - - if (oldArch != null) { - conf.set("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - System.clearProperty("spark.test.useCompressedOops") } test("StorageLevel object caching") { diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 7bca1711ae22..6bbf72e929dc 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.BlockManagerId /** * Test the AkkaUtils with various security settings. */ -class AkkaUtilsSuite extends FunSuite with LocalSparkContext { +class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { test("remote fetch security bad password") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala new file mode 100644 index 000000000000..d4b92f33dd9e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import java.util.Properties + +import org.scalatest.{BeforeAndAfterEach, Suite} + +/** + * Mixin for automatically resetting system properties that are modified in ScalaTest tests. + * This resets the properties after each individual test. + * + * The order in which fixtures are mixed in affects the order in which they are invoked by tests. + * If we have a suite `MySuite extends FunSuite with Foo with Bar`, then + * Bar's `super` is Foo, so Bar's beforeEach() will and afterEach() methods will be invoked first + * by the rest runner. + * + * This means that ResetSystemProperties should appear as the last trait in test suites that it's + * mixed into in order to ensure that the system properties snapshot occurs as early as possible. + * ResetSystemProperties calls super.afterEach() before performing its own cleanup, ensuring that + * the old properties are restored as late as possible. + * + * See the "Composing fixtures by stacking traits" section at + * http://www.scalatest.org/user_guide/sharing_fixtures for more details about this pattern. + */ +private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Suite => + var oldProperties: Properties = null + + override def beforeEach(): Unit = { + oldProperties = new Properties(System.getProperties) + super.beforeEach() + } + + override def afterEach(): Unit = { + try { + super.afterEach() + } finally { + System.setProperties(oldProperties) + oldProperties = null + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 0ea2d13a8350..7424c2e91d4f 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.util -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite -import org.scalatest.PrivateMethodTester +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} class DummyClass1 {} @@ -46,20 +44,12 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { + extends FunSuite with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { - var oldArch: String = _ - var oldOops: String = _ - - override def beforeAll() { + override def beforeEach() { // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - } - - override def afterAll() { - resetOrClear("os.arch", oldArch) - resetOrClear("spark.test.useCompressedOops", oldOops) + System.setProperty("os.arch", "amd64") + System.setProperty("spark.test.useCompressedOops", "true") } test("simple classes") { @@ -122,7 +112,7 @@ class SizeEstimatorSuite } test("32-bit arch") { - val arch = System.setProperty("os.arch", "x86") + System.setProperty("os.arch", "x86") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -131,14 +121,13 @@ class SizeEstimatorSuite assertResult(48)(SizeEstimator.estimate(DummyString("a"))) assertResult(48)(SizeEstimator.estimate(DummyString("ab"))) assertResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) - resetOrClear("os.arch", arch) } // NOTE: The String class definition varies across JDK versions (1.6 vs. 
1.7) and vendors // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("64-bit arch with no compressed oops") { - val arch = System.setProperty("os.arch", "amd64") - val oops = System.setProperty("spark.test.useCompressedOops", "false") + System.setProperty("os.arch", "amd64") + System.setProperty("spark.test.useCompressedOops", "false") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -146,16 +135,5 @@ class SizeEstimatorSuite assertResult(64)(SizeEstimator.estimate(DummyString("a"))) assertResult(64)(SizeEstimator.estimate(DummyString("ab"))) assertResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) - - resetOrClear("os.arch", arch) - resetOrClear("spark.test.useCompressedOops", oops) - } - - def resetOrClear(prop: String, oldValue: String) { - if (oldValue != null) { - System.setProperty(prop, oldValue) - } else { - System.clearProperty(prop) - } } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f9d4bea823f7..4544382094f9 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite { +class UtilsSuite extends FunSuite with ResetSystemProperties { test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index adecd934358c..1b53f3edbe92 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -28,11 +28,9 @@ object BroadcastTest { val bcName = if (args.length > 2) args(2) else "Http" val blockSize = if (args.length > 3) args(3) else "4096" - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." 
+ bcName + - "BroadcastFactory") - System.setProperty("spark.broadcast.blockSize", blockSize) val sparkConf = new SparkConf().setAppName("Broadcast Test") - + .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroaddcastFactory") + .set("spark.broadcast.blockSize", blockSize) val sc = new SparkContext(sparkConf) val slices = if (args.length > 0) args(0).toInt else 2 diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071..1e24da7f5f60 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071..1e24da7f5f60 100644 --- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071..1e24da7f5f60 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new 
Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071..1e24da7f5f60 100644 --- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 6e1f01900071..1e24da7f5f60 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -17,6 +17,7 @@ package org.apache.spark.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.junit.After; import org.junit.Before; @@ -27,8 +28,11 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { - System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); ssc.checkpoint("checkpoint"); } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index db58eb642b56..15ee95070a3d 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util.Utils @@ -49,13 +49,13 @@ object StoragePerfTester { val writeData = "1" * recordLength val executor = Executors.newFixedThreadPool(numMaps) - System.setProperty("spark.shuffle.compress", "false") - System.setProperty("spark.shuffle.sync", "true") - 
System.setProperty("spark.shuffle.manager", - "org.apache.spark.shuffle.hash.HashShuffleManager") + val conf = new SparkConf() + .set("spark.shuffle.compress", "false") + .set("spark.shuffle.sync", "true") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") // This is only used to instantiate a BlockManager. All thread scheduling is done manually. - val sc = new SparkContext("local[4]", "Write Tester") + val sc = new SparkContext("local[4]", "Write Tester", conf) val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager] def writeOutputBytes(mapId: Int, total: AtomicLong) = { From 06a9aa589c518a40a3c7cc201e89d75af77ab93e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 31 Dec 2014 11:50:53 -0800 Subject: [PATCH 035/215] [SPARK-4797] Replace breezeSquaredDistance This PR replaces slow breezeSquaredDistance. Author: Liang-Chi Hsieh Closes #3643 from viirya/faster_squareddistance and squashes the following commits: f28b275 [Liang-Chi Hsieh] Move the implementation to linalg.Vectors and rename as sqdist. 0bc48ee [Liang-Chi Hsieh] Merge branch 'master' into faster_squareddistance ba34422 [Liang-Chi Hsieh] Fix bug. 91849d0 [Liang-Chi Hsieh] Modified for comment. 44a65ad [Liang-Chi Hsieh] Modified for comments. 35db395 [Liang-Chi Hsieh] Fix bug and some modifications for comments. f4f5ebb [Liang-Chi Hsieh] Follow BLAS.dot pattern to replace intersect, diff with while-loop. a36e09f [Liang-Chi Hsieh] Use while-loop to replace foreach for better performance. d3e0628 [Liang-Chi Hsieh] Make the methods private. dd415bc [Liang-Chi Hsieh] Consider different cases of SparseVector and DenseVector. 13669db [Liang-Chi Hsieh] Replace breezeSquaredDistance. --- .../apache/spark/mllib/linalg/Vectors.scala | 80 +++++++++++++++++++ .../org/apache/spark/mllib/util/MLUtils.scala | 13 ++- .../spark/mllib/util/MLUtilsSuite.scala | 15 ++++ 3 files changed, 100 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 01f3f9057714..6a782b079aac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -312,6 +312,86 @@ object Vectors { math.pow(sum, 1.0 / p) } } + + /** + * Returns the squared distance between two Vectors. + * @param v1 first Vector. + * @param v2 second Vector. + * @return squared distance between two Vectors. 
+ */ + def sqdist(v1: Vector, v2: Vector): Double = { + var squaredDistance = 0.0 + (v1, v2) match { + case (v1: SparseVector, v2: SparseVector) => + val v1Values = v1.values + val v1Indices = v1.indices + val v2Values = v2.values + val v2Indices = v2.indices + val nnzv1 = v1Indices.size + val nnzv2 = v2Indices.size + + var kv1 = 0 + var kv2 = 0 + while (kv1 < nnzv1 || kv2 < nnzv2) { + var score = 0.0 + + if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) { + score = v1Values(kv1) + kv1 += 1 + } else if (kv1 >= nnzv1 || (kv2 < nnzv2 && v2Indices(kv2) < v1Indices(kv1))) { + score = v2Values(kv2) + kv2 += 1 + } else { + score = v1Values(kv1) - v2Values(kv2) + kv1 += 1 + kv2 += 1 + } + squaredDistance += score * score + } + + case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + squaredDistance = sqdist(v1, v2) + + case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + squaredDistance = sqdist(v2, v1) + + // When a SparseVector is approximately dense, we treat it as a DenseVector + case (v1, v2) => + squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => + val score = elems._1 - elems._2 + distance + score * score + } + } + squaredDistance + } + + /** + * Returns the squared distance between DenseVector and SparseVector. + */ + private[mllib] def sqdist(v1: SparseVector, v2: DenseVector): Double = { + var kv1 = 0 + var kv2 = 0 + val indices = v1.indices + var squaredDistance = 0.0 + var iv1 = indices(kv1) + val nnzv2 = v2.size + + while (kv2 < nnzv2) { + var score = 0.0 + if (kv2 != iv1) { + score = v2(kv2) + } else { + score = v1.values(kv1) - v2(kv2) + if (kv1 < indices.length - 1) { + kv1 += 1 + iv1 = indices(kv1) + } + } + squaredDistance += score * score + kv2 += 1 + } + squaredDistance + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index da0da0a168c1..c7843464a750 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, - squaredDistance => breezeSquaredDistance} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.util.random.BernoulliCellSampler import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext @@ -266,7 +265,7 @@ object MLUtils { } Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: @@ -316,12 +315,10 @@ object MLUtils { val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / (sqDist + EPSILON) if (precisionBound2 > precision) { - // TODO: breezeSquaredDistance is slow, - // so we should replace it with our own implementation. 
- sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze) + sqDist = Vectors.sqdist(v1, v2) } } else { - sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze) + sqDist = Vectors.sqdist(v1, v2) } sqDist } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index df07987093fb..7778847f8b72 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -52,12 +52,27 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val values = indices.map(i => a(i)) val v2 = Vectors.sparse(n, indices, values) val norm2 = Vectors.norm(v2, 2.0) + val v3 = Vectors.sparse(n, indices, indices.map(i => a(i) + 0.5)) + val norm3 = Vectors.norm(v3, 2.0) val squaredDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze) val fastSquaredDist1 = fastSquaredDistance(v1, norm1, v2, norm2, precision) assert((fastSquaredDist1 - squaredDist) <= precision * squaredDist, s"failed with m = $m") val fastSquaredDist2 = fastSquaredDistance(v1, norm1, Vectors.dense(v2.toArray), norm2, precision) assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m") + val squaredDist2 = breezeSquaredDistance(v2.toBreeze, v3.toBreeze) + val fastSquaredDist3 = + fastSquaredDistance(v2, norm2, v3, norm3, precision) + assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") + if (m > 10) { + val v4 = Vectors.sparse(n, indices.slice(0, m - 10), + indices.map(i => a(i) + 0.5).slice(0, m - 10)) + val norm4 = Vectors.norm(v4, 2.0) + val squaredDist = breezeSquaredDistance(v2.toBreeze, v4.toBreeze) + val fastSquaredDist = + fastSquaredDistance(v2, norm2, v4, norm4, precision) + assert((fastSquaredDist - squaredDist) <= precision * squaredDist, s"failed with m = $m") + } } } From 8e14c5eb551ab06c94859c7f6d8c6b62b4d00d59 Mon Sep 17 00:00:00 2001 From: Brennon York Date: Wed, 31 Dec 2014 11:54:10 -0800 Subject: [PATCH 036/215] [SPARK-4298][Core] - The spark-submit cannot read Main-Class from Manifest. Resolves a bug where the `Main-Class` from a .jar file wasn't being read in properly. This was caused by the fact that the `primaryResource` object was a URI and needed to be normalized through a call to `.getPath` before it could be passed into the `JarFile` object. 
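To make the normalization step concrete, here is a small self-contained sketch (the jar path is a hypothetical placeholder, and this is illustrative rather than the actual SparkSubmitArguments code):

```
import java.net.URI
import java.util.jar.JarFile

// Hypothetical local jar; a raw "file:/..." URI string cannot be handed to JarFile directly.
val primaryResource = "file:/tmp/example-app.jar"
val uri = new URI(primaryResource)
// Normalize the URI to a filesystem path before opening the jar, then read
// Main-Class from the manifest (it may be absent, hence the Option).
val mainClass: Option[String] =
  Option(new JarFile(uri.getPath).getManifest)
    .flatMap(m => Option(m.getMainAttributes.getValue("Main-Class")))
```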
Author: Brennon York Closes #3561 from brennonyork/SPARK-4298 and squashes the following commits: 5e0fce1 [Brennon York] Use string interpolation for error messages, moved comment line from original code to above its necessary code segment 14daa20 [Brennon York] pushed mainClass assignment into match statement, removed spurious spaces, removed { } from case statements, removed return values c6dad68 [Brennon York] Set case statement to support multiple jar URI's and enabled the 'file' URI to load the main-class 8d20936 [Brennon York] updated to reset the error message back to the default a043039 [Brennon York] updated to split the uri and jar vals 8da7cbf [Brennon York] fixes SPARK-4298 --- .../spark/deploy/SparkSubmitArguments.scala | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f174bc1af59b..1faabe91f49a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import java.net.URI import java.util.jar.JarFile import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -125,14 +126,23 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && primaryResource != null) { - try { - val jar = new JarFile(primaryResource) - // Note that this might still return null if no main-class is set; we catch that later - mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class") - } catch { - case e: Exception => - SparkSubmit.printErrorAndExit("Cannot load main class from JAR: " + primaryResource) - return + val uri = new URI(primaryResource) + val uriScheme = uri.getScheme() + + uriScheme match { + case "file" => + try { + val jar = new JarFile(uri.getPath) + // Note that this might still return null if no main-class is set; we catch that later + mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class") + } catch { + case e: Exception => + SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") + } + case _ => + SparkSubmit.printErrorAndExit( + s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " + + "Please specify a class through --class.") } } From 3d194cc75761fceba77b2c91291b36479b8b556c Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 31 Dec 2014 13:37:04 -0800 Subject: [PATCH 037/215] SPARK-4547 [MLLIB] OOM when making bins in BinaryClassificationMetrics Now that I've implemented the basics here, I'm less convinced there is a need for this change, somehow. Callers can downsample before or after. Really the OOM is not in the ROC curve code, but in code that might `collect()` it for local analysis. Still, might be useful to down-sample since the ROC curve probably never needs millions of points. This is a first pass. Since the `(score,label)` are already grouped and sorted, I think it's sufficient to just take every Nth such pair, in order to downsample by a factor of N? this is just like retaining every Nth point on the curve, which I think is the goal. All of the data is still used to build the curve of course. What do you think about the API, and usefulness? 
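As a rough standalone illustration of the binning idea (independent of the actual BinaryClassificationMetrics code), assuming the curve is already a sorted sequence of (score, count) pairs:

```
// Collapse each run of `grouping` consecutive (score, count) pairs into one
// point: keep the first score of the chunk and sum the counts it covers.
def downsampleCurve(curve: Seq[(Double, Long)], grouping: Int): Seq[(Double, Long)] =
  curve.grouped(grouping).map { chunk =>
    (chunk.head._1, chunk.map(_._2).sum)
  }.toSeq

// e.g. downsampleCurve(Seq((0.9, 3L), (0.8, 1L), (0.7, 2L), (0.6, 5L)), 2)
//   == Seq((0.9, 4L), (0.7, 7L))
```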
Author: Sean Owen Closes #3702 from srowen/SPARK-4547 and squashes the following commits: 1d34d05 [Sean Owen] Indent and reorganize numBins scaladoc 692d825 [Sean Owen] Change handling of large numBins, make 2nd consturctor instead of optional param, style change a03610e [Sean Owen] Add downsamplingFactor to BinaryClassificationMetrics --- .../BinaryClassificationMetrics.scala | 59 ++++++++++++++++++- .../BinaryClassificationMetricsSuite.scala | 36 +++++++++++ 2 files changed, 92 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 1af40de2c7fc..ced042e2f96c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -28,9 +28,30 @@ import org.apache.spark.rdd.{RDD, UnionRDD} * Evaluator for binary classification. * * @param scoreAndLabels an RDD of (score, label) pairs. + * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally + * will be down-sampled to this many "bins". If 0, no down-sampling will occur. + * This is useful because the curve contains a point for each distinct score + * in the input, and this could be as large as the input itself -- millions of + * points or more, when thousands may be entirely sufficient to summarize + * the curve. After down-sampling, the curves will instead be made of approximately + * `numBins` points instead. Points are made from bins of equal numbers of + * consecutive points. The size of each bin is + * `floor(scoreAndLabels.count() / numBins)`, which means the resulting number + * of bins may not exactly equal numBins. The last bin in each partition may + * be smaller as a result, meaning there may be an extra sample at + * partition boundaries. */ @Experimental -class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends Logging { +class BinaryClassificationMetrics( + val scoreAndLabels: RDD[(Double, Double)], + val numBins: Int) extends Logging { + + require(numBins >= 0, "numBins must be nonnegative") + + /** + * Defaults `numBins` to 0. + */ + def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) /** Unpersist intermediate RDDs used in the computation. 
*/ def unpersist() { @@ -103,7 +124,39 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends mergeValue = (c: BinaryLabelCounter, label: Double) => c += label, mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 ).sortByKey(ascending = false) - val agg = counts.values.mapPartitions { iter => + + val binnedCounts = + // Only down-sample if bins is > 0 + if (numBins == 0) { + // Use original directly + counts + } else { + val countsSize = counts.count() + // Group the iterator into chunks of about countsSize / numBins points, + // so that the resulting number of bins is about numBins + var grouping = countsSize / numBins + if (grouping < 2) { + // numBins was more than half of the size; no real point in down-sampling to bins + logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful") + counts + } else { + if (grouping >= Int.MaxValue) { + logWarning( + s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}") + grouping = Int.MaxValue + } + counts.mapPartitions(_.grouped(grouping.toInt).map { pairs => + // The score of the combined point will be just the first one's score + val firstScore = pairs.head._1 + // The point will contain all counts in this chunk + val agg = new BinaryLabelCounter() + pairs.foreach(pair => agg += pair._2) + (firstScore, agg) + }) + } + } + + val agg = binnedCounts.values.mapPartitions { iter => val agg = new BinaryLabelCounter() iter.foreach(agg += _) Iterator(agg) @@ -113,7 +166,7 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends (agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c) val totalCount = partitionwiseCumulativeCounts.last logInfo(s"Total counts: $totalCount") - val cumulativeCounts = counts.mapPartitionsWithIndex( + val cumulativeCounts = binnedCounts.mapPartitionsWithIndex( (index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => { val cumCount = partitionwiseCumulativeCounts(index) iter.map { case (score, c) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 8a18e2971cab..e0224f960cc4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -124,4 +124,40 @@ class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkConte validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } + + test("binary evaluation metrics with downsampling") { + val scoreAndLabels = Seq( + (0.1, 0.0), (0.2, 0.0), (0.3, 1.0), (0.4, 0.0), (0.5, 0.0), + (0.6, 1.0), (0.7, 1.0), (0.8, 0.0), (0.9, 1.0)) + + val scoreAndLabelsRDD = sc.parallelize(scoreAndLabels, 1) + + val original = new BinaryClassificationMetrics(scoreAndLabelsRDD) + val originalROC = original.roc().collect().sorted.toList + // Add 2 for (0,0) and (1,1) appended at either end + assert(2 + scoreAndLabels.size == originalROC.size) + assert( + List( + (0.0, 0.0), (0.0, 0.25), (0.2, 0.25), (0.2, 0.5), (0.2, 0.75), + (0.4, 0.75), (0.6, 0.75), (0.6, 1.0), (0.8, 1.0), (1.0, 1.0), + (1.0, 1.0) + ) == + originalROC) + + val numBins = 4 + + val downsampled = new BinaryClassificationMetrics(scoreAndLabelsRDD, numBins) + val downsampledROC = downsampled.roc().collect().sorted.toList + assert( + // May have to add 1 if the 
sample factor didn't divide evenly + 2 + (numBins + (if (scoreAndLabels.size % numBins == 0) 0 else 1)) == + downsampledROC.size) + assert( + List( + (0.0, 0.0), (0.2, 0.25), (0.2, 0.75), (0.6, 0.75), (0.8, 1.0), + (1.0, 1.0), (1.0, 1.0) + ) == + downsampledROC) + } + } From e24d3a9a29962023cc722896a14c7bfe06e8e601 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 12 Dec 2014 12:38:37 -0800 Subject: [PATCH 038/215] [HOTFIX] Disable Spark UI in SparkSubmitSuite tests This should fix a major cause of build breaks when running many parallel tests. --- .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 5eda2d41f0e6..065b7534cece 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -290,6 +290,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -304,6 +305,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, + "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } From c88a3d7fca20d36ee566d48e0cb91fe33a7a6d99 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 31 Dec 2014 14:25:03 -0800 Subject: [PATCH 039/215] [SPARK-5038][SQL] Add explicit return type for implicit functions in Spark SQL As we learned in https://github.com/apache/spark/pull/3580, not explicitly typing implicit functions can lead to compiler bugs and potentially unexpected runtime behavior. Author: Reynold Xin Closes #3859 from rxin/sql-implicits and squashes the following commits: 30c2c24 [Reynold Xin] [SPARK-5038] Add explicit return type for implicit functions in Spark SQL. --- .../spark/sql/catalyst/dsl/package.scala | 80 +++++++++---------- .../org/apache/spark/sql/SQLContext.scala | 2 +- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a14e5b9ef14d..8e39f79d2ca5 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.types.decimal.Decimal - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -29,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * A collection of implicit conversions that create a DSL for constructing catalyst data structures. 
@@ -119,21 +118,22 @@ package object dsl { def expr = e } - implicit def booleanToLiteral(b: Boolean) = Literal(b) - implicit def byteToLiteral(b: Byte) = Literal(b) - implicit def shortToLiteral(s: Short) = Literal(s) - implicit def intToLiteral(i: Int) = Literal(i) - implicit def longToLiteral(l: Long) = Literal(l) - implicit def floatToLiteral(f: Float) = Literal(f) - implicit def doubleToLiteral(d: Double) = Literal(d) - implicit def stringToLiteral(s: String) = Literal(s) - implicit def dateToLiteral(d: Date) = Literal(d) - implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) - implicit def decimalToLiteral(d: Decimal) = Literal(d) - implicit def timestampToLiteral(t: Timestamp) = Literal(t) - implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) - - implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) + implicit def booleanToLiteral(b: Boolean): Literal = Literal(b) + implicit def byteToLiteral(b: Byte): Literal = Literal(b) + implicit def shortToLiteral(s: Short): Literal = Literal(s) + implicit def intToLiteral(i: Int): Literal = Literal(i) + implicit def longToLiteral(l: Long): Literal = Literal(l) + implicit def floatToLiteral(f: Float): Literal = Literal(f) + implicit def doubleToLiteral(d: Double): Literal = Literal(d) + implicit def stringToLiteral(s: String): Literal = Literal(s) + implicit def dateToLiteral(d: Date): Literal = Literal(d) + implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d) + implicit def decimalToLiteral(d: Decimal): Literal = Literal(d) + implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t) + implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a) + + implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute = + analysis.UnresolvedAttribute(s.name) def sum(e: Expression) = Sum(e) def sumDistinct(e: Expression) = SumDistinct(e) @@ -301,52 +301,52 @@ package object dsl { (1 to 22).map { x => val argTypes = Seq.fill(x)("_").mkString(", ") - s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)" + s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func)" } */ - implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, 
_, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = 
ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7a1330222901..6a1a4d995bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -106,7 +106,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { + implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = { SparkPlan.currentContext.set(self) val attributeSeq = ScalaReflection.attributesFor[A] val schema = StructType.fromAttributes(attributeSeq) From 3610d3c615112faef98d94f04efaea602cc4aa8f Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Wed, 31 Dec 2014 14:35:07 -0800 Subject: [PATCH 040/215] [SPARK-4790][STREAMING] Fix ReceivedBlockTrackerSuite waits for old file... ...s to get deleted before continuing. Since the deletes are happening asynchronously, the getFileStatus call might throw an exception in older HDFS versions, if the delete happens between the time listFiles is called on the directory and getFileStatus is called on the file in the getFileStatus method. This PR addresses this by adding an option to delete the files synchronously and then waiting for the deletion to complete before proceeding. Author: Hari Shreedharan Closes #3726 from harishreedharan/spark-4790 and squashes the following commits: bbbacd1 [Hari Shreedharan] Call cleanUpOldLogs only once in the tests. 3255f17 [Hari Shreedharan] Add test for async deletion. Remove method from ReceiverTracker that does not take waitForCompletion. e4c83ec [Hari Shreedharan] Making waitForCompletion a mandatory param. Remove eventually from WALSuite since the cleanup method returns only after all files are deleted. af00fd1 [Hari Shreedharan] [SPARK-4790][STREAMING] Fix ReceivedBlockTrackerSuite waits for old files to get deleted before continuing. 
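A simplified sketch of the synchronous-deletion option described above (the names here are illustrative, not the actual WriteAheadLogManager API): deletion still runs on a Future, but callers that pass waitForCompletion = true block until it finishes.

```
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Run the cleanup asynchronously, but optionally wait so that tests (or other
// callers needing a consistent view of the log directory) see the files gone.
def cleanupOldLogs(deleteFiles: () => Unit, waitForCompletion: Boolean)
                  (implicit ec: ExecutionContext): Unit = {
  val f = Future { deleteFiles() }
  if (waitForCompletion) {
    Await.ready(f, 1.second)
  }
}
```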
--- .../receiver/ReceivedBlockHandler.scala | 8 ++++---- .../scheduler/ReceivedBlockTracker.scala | 9 ++++++--- .../streaming/scheduler/ReceiverTracker.scala | 2 +- .../streaming/util/WriteAheadLogManager.scala | 17 +++++++++++++---- .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- .../streaming/ReceivedBlockTrackerSuite.scala | 2 +- .../streaming/util/WriteAheadLogSuite.scala | 18 ++++++++++++++++-- 7 files changed, 42 insertions(+), 16 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 8b97db8dd36f..f7a8ebee8a54 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -42,7 +42,7 @@ private[streaming] trait ReceivedBlockHandler { def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult /** Cleanup old blocks older than the given threshold time */ - def cleanupOldBlock(threshTime: Long) + def cleanupOldBlocks(threshTime: Long) } @@ -82,7 +82,7 @@ private[streaming] class BlockManagerBasedBlockHandler( BlockManagerBasedStoreResult(blockId) } - def cleanupOldBlock(threshTime: Long) { + def cleanupOldBlocks(threshTime: Long) { // this is not used as blocks inserted into the BlockManager are cleared by DStream's clearing // of BlockRDDs. } @@ -192,8 +192,8 @@ private[streaming] class WriteAheadLogBasedBlockHandler( WriteAheadLogBasedStoreResult(blockId, segment) } - def cleanupOldBlock(threshTime: Long) { - logManager.cleanupOldLogs(threshTime) + def cleanupOldBlocks(threshTime: Long) { + logManager.cleanupOldLogs(threshTime, waitForCompletion = false) } def stop() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 02758e0bca6c..2ce458cddec1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -139,14 +139,17 @@ private[streaming] class ReceivedBlockTracker( getReceivedBlockQueue(streamId).toSeq } - /** Clean up block information of old batches. */ - def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized { + /** + * Clean up block information of old batches. If waitForCompletion is true, this method + * returns only after the files are cleaned up. 
+ */ + def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { assert(cleanupThreshTime.milliseconds < clock.currentTime()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup - logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds)) + logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) log } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 1f0e442a1228..8dbb42a86e3b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -121,7 +121,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Clean up metadata older than the given threshold time */ def cleanupOldMetadata(cleanupThreshTime: Time) { - receivedBlockTracker.cleanupOldBatches(cleanupThreshTime) + receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) } /** Register a receiver */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala index 70d234320be7..166661b7496d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala @@ -19,11 +19,11 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.permission.FsPermission import org.apache.spark.Logging import org.apache.spark.util.Utils import WriteAheadLogManager._ @@ -124,8 +124,12 @@ private[streaming] class WriteAheadLogManager( * files, which is usually based on the local system time. So if there is coordination necessary * between the node calculating the threshTime (say, driver node), and the local system time * (say, worker node), the caller has to take account of possible time skew. + * + * If waitForCompletion is set to true, this method will return only after old logs have been + * deleted. This should be set to true only for testing. Else the files will be deleted + * asynchronously. 
*/ - def cleanupOldLogs(threshTime: Long): Unit = { + def cleanupOldLogs(threshTime: Long, waitForCompletion: Boolean): Unit = { val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") @@ -146,10 +150,15 @@ private[streaming] class WriteAheadLogManager( logInfo(s"Cleared log files in $logDirectory older than $threshTime") } if (!executionContext.isShutdown) { - Future { deleteFiles() } + val f = Future { deleteFiles() } + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } } } + /** Stop the manager, close any open log writer */ def stop(): Unit = synchronized { if (currentLogWriter != null) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 3661e16a9ef2..132ff2443fc0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -168,7 +168,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche manualClock.currentTime() shouldEqual 5000L val cleanupThreshTime = 3000L - handler.cleanupOldBlock(cleanupThreshTime) + handler.cleanupOldBlocks(cleanupThreshTime) eventually(timeout(10000 millis), interval(10 millis)) { getWriteAheadLogFiles().size should be < preCleanupLogFiles.size } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 01a09b67b99d..de7e9d624bf6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -166,7 +166,7 @@ class ReceivedBlockTrackerSuite // Cleanup first batch but not second batch val oldestLogFile = getWriteAheadLogFiles().head incrementTime() - tracker3.cleanupOldBatches(batchTime2) + tracker3.cleanupOldBatches(batchTime2, waitForCompletion = true) // Verify that the batch allocations have been cleaned, and the act has been written to log tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8f69bcb64279..7ce9499dc614 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -182,15 +182,29 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } test("WriteAheadLogManager - cleanup old logs") { + logCleanUpTest(waitForCompletion = false) + } + + test("WriteAheadLogManager - cleanup old logs synchronously") { + logCleanUpTest(waitForCompletion = true) + } + + private def logCleanUpTest(waitForCompletion: Boolean): Unit = { // Write data with manager, recover with new manager and verify val manualClock = new ManualClock val dataToWrite = generateRandomData() manager = writeDataUsingManager(testDir, dataToWrite, manualClock, stopManager = false) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - 
manager.cleanupOldLogs(manualClock.currentTime() / 2) - eventually(timeout(1 second), interval(10 milliseconds)) { + + manager.cleanupOldLogs(manualClock.currentTime() / 2, waitForCompletion) + + if (waitForCompletion) { assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } else { + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } } } From fdc2aa4918fd4c510f04812b782cc0bfef9a2107 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Dec 2014 14:45:31 -0800 Subject: [PATCH 041/215] [SPARK-5028][Streaming]Add total received and processed records metrics to Streaming UI This is a follow-up work of [SPARK-4537](https://issues.apache.org/jira/browse/SPARK-4537). Adding total received records and processed records metrics back to UI. ![screenshot](https://dl.dropboxusercontent.com/u/19230832/screenshot.png) Author: jerryshao Closes #3852 from jerryshao/SPARK-5028 and squashes the following commits: c8c4877 [jerryshao] Add total received and processed metrics to Streaming UI --- .../scala/org/apache/spark/streaming/ui/StreamingPage.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 1353e487c72c..98e9a2e639e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -67,6 +67,12 @@ private[ui] class StreamingPage(parent: StreamingTab)
       <li>
         <strong>Waiting batches: </strong>{listener.numUnprocessedBatches}
       </li>
+      <li>
+        <strong>Received records: </strong>{listener.numTotalReceivedRecords}
+      </li>
+      <li>
+        <strong>Processed records: </strong>{listener.numTotalProcessedRecords}
+      </li>
     </ul>
  • } From c4f0b4f334f7f3565375921fcac184ad5b1fb207 Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Wed, 31 Dec 2014 15:39:58 -0800 Subject: [PATCH 042/215] SPARK-5020 [MLlib] GaussianMixtureModel.predictMembership() should take an RDD only Removed unnecessary parameters to predictMembership() CC: jkbradley Author: Travis Galoppo Closes #3854 from tgaloppo/spark-5020 and squashes the following commits: 1bf4669 [Travis Galoppo] renamed predictMembership() to predictSoft() 0f1d96e [Travis Galoppo] SPARK-5020 - Removed superfluous parameters from predictMembership() --- .../spark/mllib/clustering/GaussianMixtureModel.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 11a110db1f7c..b461ea4f0f06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -45,7 +45,7 @@ class GaussianMixtureModel( /** Maps given points to their cluster indices. */ def predict(points: RDD[Vector]): RDD[Int] = { - val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k) + val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } @@ -53,12 +53,7 @@ class GaussianMixtureModel( * Given the input vectors, return the membership value of each vector * to all mixture components. */ - def predictMembership( - points: RDD[Vector], - mu: Array[Vector], - sigma: Array[Matrix], - weight: Array[Double], - k: Int): RDD[Array[Double]] = { + def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val dists = sc.broadcast { (0 until k).map { i => From fe6efacc0b865e9e827a1565877077000e63976e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 31 Dec 2014 16:02:47 -0800 Subject: [PATCH 043/215] [SPARK-5035] [Streaming] ReceiverMessage trait should extend Serializable Spark Streaming's ReceiverMessage trait should extend Serializable in order to fix a subtle bug that only occurs when running on a real cluster: If you attempt to send a fire-and-forget message to a remote Akka actor and that message cannot be serialized, then this seems to lead to more-or-less silent failures. As an optimization, Akka skips message serialization for messages sent within the same JVM. As a result, Spark's unit tests will never fail due to non-serializable Akka messages, but these will cause mostly-silent failures when running on a real cluster. Before this patch, here was the code for ReceiverMessage: ``` /** Messages sent to the NetworkReceiver. */ private[streaming] sealed trait ReceiverMessage private[streaming] object StopReceiver extends ReceiverMessage ``` Since ReceiverMessage does not extend Serializable and StopReceiver is a regular `object`, not a `case object`, StopReceiver will throw serialization errors. As a result, graceful receiver shutdown is broken on real clusters (and local-cluster mode) but works in local modes. 
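A standalone sketch of the point being made (the names below are illustrative, not Spark's): a message sent to a remote actor must be serializable, which can be achieved either by extending Serializable on the sealed trait, as this patch does, or by declaring the message as a case object, which is serializable by default.

```
// Illustrative message hierarchy; a plain object is not Serializable unless
// the trait (or the object itself) says so, while a case object always is.
sealed trait ReceiverCommand extends Serializable
object StopCommand extends ReceiverCommand          // serializable via the trait
case object RestartCommand extends ReceiverCommand  // serializable as a case object regardless
```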
If you want to reproduce this, try running the word count example from the Streaming Programming Guide in the Spark shell: ``` import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ val ssc = new StreamingContext(sc, Seconds(10)) // Create a DStream that will connect to hostname:port, like localhost:9999 val lines = ssc.socketTextStream("localhost", 9999) // Split each line into words val words = lines.flatMap(_.split(" ")) import org.apache.spark.streaming.StreamingContext._ // Count each word in each batch val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) // Print the first ten elements of each RDD generated in this DStream to the console wordCounts.print() ssc.start() Thread.sleep(10000) ssc.stop(true, true) ``` Prior to this patch, this would work correctly in local mode but fail when running against a real cluster (it would report that some receivers were not shut down). Author: Josh Rosen Closes #3857 from JoshRosen/SPARK-5035 and squashes the following commits: 71d0eae [Josh Rosen] [SPARK-5035] ReceiverMessage trait should extend Serializable. --- .../org/apache/spark/streaming/receiver/ReceiverMessage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index bf39d1e891ca..ab9fa192191a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -18,6 +18,6 @@ package org.apache.spark.streaming.receiver /** Messages sent to the NetworkReceiver. */ -private[streaming] sealed trait ReceiverMessage +private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage From 4bb12488d56ea651c56d9688996b464b99095582 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 31 Dec 2014 16:59:17 -0800 Subject: [PATCH 044/215] SPARK-2757 [BUILD] [STREAMING] Add Mima test for Spark Sink after 1.10 is released Re-enable MiMa for Streaming Flume Sink module, now that 1.1.0 is released, per the JIRA TO-DO. That's pretty much all there is to this. Author: Sean Owen Closes #3842 from srowen/SPARK-2757 and squashes the following commits: 50ff80e [Sean Owen] Exclude apparent false positive turned up by re-enabling MiMa checks for Streaming Flume Sink 0e5ba5c [Sean Owen] Re-enable MiMa for Streaming Flume Sink module --- project/MimaExcludes.scala | 5 +++++ project/SparkBuild.scala | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 230239aa4050..c377e5cffa7d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -53,6 +53,11 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-2757 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ + "removeAndGetProcessor") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c512b62f6137..46a54c681840 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -166,7 +166,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } From 7749dd6c36a182478b20f4636734c8db0b7ddb00 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 31 Dec 2014 17:07:47 -0800 Subject: [PATCH 045/215] [SPARK-5038] Add explicit return type for implicit functions. As we learned in #3580, not explicitly typing implicit functions can lead to compiler bugs and potentially unexpected runtime behavior. This is a follow up PR for rest of Spark (outside Spark SQL). The original PR for Spark SQL can be found at https://github.com/apache/spark/pull/3859 Author: Reynold Xin Closes #3860 from rxin/implicit and squashes the following commits: 73702f9 [Reynold Xin] [SPARK-5038] Add explicit return type for implicit functions. --- .../scala/org/apache/spark/SparkContext.scala | 14 ++--- .../scala/org/apache/spark/util/Vector.scala | 38 +++++------ .../graphx/impl/EdgePartitionBuilder.scala | 63 ++++++++++--------- .../impl/ShippableVertexPartition.scala | 4 +- .../spark/graphx/impl/VertexPartition.scala | 4 +- .../graphx/impl/VertexPartitionBaseOps.scala | 4 +- 6 files changed, 64 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 57bc3d4e4ae3..df1cb3cda2db 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1708,19 +1708,19 @@ object SparkContext extends Logging { // Implicit conversions to common Writable types, for saveAsSequenceFile - implicit def intToIntWritable(i: Int) = new IntWritable(i) + implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) - implicit def longToLongWritable(l: Long) = new LongWritable(l) + implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) - implicit def floatToFloatWritable(f: Float) = new FloatWritable(f) + implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) - implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d) + implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) - implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b) + implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) - implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob) + implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) - implicit def stringToText(s: String) = new Text(s) + implicit def stringToText(s: String): Text = new Text(s) private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) : ArrayWritable = { diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index c6cab82c3e54..2ed827eab46d 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -24,9 
+24,9 @@ import org.apache.spark.util.random.XORShiftRandom @deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") class Vector(val elements: Array[Double]) extends Serializable { - def length = elements.length + def length: Int = elements.length - def apply(index: Int) = elements(index) + def apply(index: Int): Double = elements(index) def + (other: Vector): Vector = { if (length != other.length) { @@ -35,7 +35,7 @@ class Vector(val elements: Array[Double]) extends Serializable { Vector(length, i => this(i) + other(i)) } - def add(other: Vector) = this + other + def add(other: Vector): Vector = this + other def - (other: Vector): Vector = { if (length != other.length) { @@ -44,7 +44,7 @@ class Vector(val elements: Array[Double]) extends Serializable { Vector(length, i => this(i) - other(i)) } - def subtract(other: Vector) = this - other + def subtract(other: Vector): Vector = this - other def dot(other: Vector): Double = { if (length != other.length) { @@ -93,19 +93,19 @@ class Vector(val elements: Array[Double]) extends Serializable { this } - def addInPlace(other: Vector) = this +=other + def addInPlace(other: Vector): Vector = this +=other def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - def multiply (d: Double) = this * d + def multiply (d: Double): Vector = this * d def / (d: Double): Vector = this * (1 / d) - def divide (d: Double) = this / d + def divide (d: Double): Vector = this / d - def unary_- = this * -1 + def unary_- : Vector = this * -1 - def sum = elements.reduceLeft(_ + _) + def sum: Double = elements.reduceLeft(_ + _) def squaredDist(other: Vector): Double = { var ans = 0.0 @@ -119,40 +119,40 @@ class Vector(val elements: Array[Double]) extends Serializable { def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - override def toString = elements.mkString("(", ", ", ")") + override def toString: String = elements.mkString("(", ", ", ")") } object Vector { - def apply(elements: Array[Double]) = new Vector(elements) + def apply(elements: Array[Double]): Vector = new Vector(elements) - def apply(elements: Double*) = new Vector(elements.toArray) + def apply(elements: Double*): Vector = new Vector(elements.toArray) def apply(length: Int, initializer: Int => Double): Vector = { val elements: Array[Double] = Array.tabulate(length)(initializer) new Vector(elements) } - def zeros(length: Int) = new Vector(new Array[Double](length)) + def zeros(length: Int): Vector = new Vector(new Array[Double](length)) - def ones(length: Int) = Vector(length, _ => 1) + def ones(length: Int): Vector = Vector(length, _ => 1) /** * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. 
*/ - def random(length: Int, random: Random = new XORShiftRandom()) = + def random(length: Int, random: Random = new XORShiftRandom()): Vector = Vector(length, _ => random.nextDouble()) class Multiplier(num: Double) { - def * (vec: Vector) = vec * num + def * (vec: Vector): Vector = vec * num } - implicit def doubleToMultiplier(num: Double) = new Multiplier(num) + implicit def doubleToMultiplier(num: Double): Multiplier = new Multiplier(num) implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector) = t1 + t2 + def addInPlace(t1: Vector, t2: Vector): Vector = t1 + t2 - def zero(initialValue: Vector) = Vector.zeros(initialValue.length) + def zero(initialValue: Vector): Vector = Vector.zeros(initialValue.length) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 409cf60977f6..906d42328fcb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -129,44 +129,45 @@ private[impl] case class EdgeWithLocalIds[@specialized ED]( srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED) private[impl] object EdgeWithLocalIds { - implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] { - override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { - if (a.srcId == b.srcId) { - if (a.dstId == b.dstId) 0 - else if (a.dstId < b.dstId) -1 + implicit def lexicographicOrdering[ED]: Ordering[EdgeWithLocalIds[ED]] = + new Ordering[EdgeWithLocalIds[ED]] { + override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { + if (a.srcId == b.srcId) { + if (a.dstId == b.dstId) 0 + else if (a.dstId < b.dstId) -1 + else 1 + } else if (a.srcId < b.srcId) -1 else 1 - } else if (a.srcId < b.srcId) -1 - else 1 + } } - } - private[graphx] def edgeArraySortDataFormat[ED] - = new SortDataFormat[EdgeWithLocalIds[ED], Array[EdgeWithLocalIds[ED]]] { - override def getKey( - data: Array[EdgeWithLocalIds[ED]], pos: Int): EdgeWithLocalIds[ED] = { - data(pos) - } + private[graphx] def edgeArraySortDataFormat[ED] = { + new SortDataFormat[EdgeWithLocalIds[ED], Array[EdgeWithLocalIds[ED]]] { + override def getKey(data: Array[EdgeWithLocalIds[ED]], pos: Int): EdgeWithLocalIds[ED] = { + data(pos) + } - override def swap(data: Array[EdgeWithLocalIds[ED]], pos0: Int, pos1: Int): Unit = { - val tmp = data(pos0) - data(pos0) = data(pos1) - data(pos1) = tmp - } + override def swap(data: Array[EdgeWithLocalIds[ED]], pos0: Int, pos1: Int): Unit = { + val tmp = data(pos0) + data(pos0) = data(pos1) + data(pos1) = tmp + } - override def copyElement( - src: Array[EdgeWithLocalIds[ED]], srcPos: Int, - dst: Array[EdgeWithLocalIds[ED]], dstPos: Int) { - dst(dstPos) = src(srcPos) - } + override def copyElement( + src: Array[EdgeWithLocalIds[ED]], srcPos: Int, + dst: Array[EdgeWithLocalIds[ED]], dstPos: Int) { + dst(dstPos) = src(srcPos) + } - override def copyRange( - src: Array[EdgeWithLocalIds[ED]], srcPos: Int, - dst: Array[EdgeWithLocalIds[ED]], dstPos: Int, length: Int) { - System.arraycopy(src, srcPos, dst, dstPos, length) - } + override def copyRange( + src: Array[EdgeWithLocalIds[ED]], srcPos: Int, + dst: Array[EdgeWithLocalIds[ED]], dstPos: Int, length: Int) { + System.arraycopy(src, srcPos, dst, dstPos, length) + } - override def allocate(length: Int): 
Array[EdgeWithLocalIds[ED]] = { - new Array[EdgeWithLocalIds[ED]](length) + override def allocate(length: Int): Array[EdgeWithLocalIds[ED]] = { + new Array[EdgeWithLocalIds[ED]](length) + } } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index 5412d720475d..aa320088f208 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -74,8 +74,8 @@ object ShippableVertexPartition { * Implicit conversion to allow invoking `VertexPartitionBase` operations directly on a * `ShippableVertexPartition`. */ - implicit def shippablePartitionToOps[VD: ClassTag](partition: ShippableVertexPartition[VD]) = - new ShippableVertexPartitionOps(partition) + implicit def shippablePartitionToOps[VD: ClassTag](partition: ShippableVertexPartition[VD]) + : ShippableVertexPartitionOps[VD] = new ShippableVertexPartitionOps(partition) /** * Implicit evidence that `ShippableVertexPartition` is a member of the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index 55c7a19d1bda..fbe53acfc32a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -38,8 +38,8 @@ private[graphx] object VertexPartition { * Implicit conversion to allow invoking `VertexPartitionBase` operations directly on a * `VertexPartition`. */ - implicit def partitionToOps[VD: ClassTag](partition: VertexPartition[VD]) = - new VertexPartitionOps(partition) + implicit def partitionToOps[VD: ClassTag](partition: VertexPartition[VD]) + : VertexPartitionOps[VD] = new VertexPartitionOps(partition) /** * Implicit evidence that `VertexPartition` is a member of the `VertexPartitionBaseOpsConstructor` diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index b40aa1b417a0..4fd2548b7faf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -238,8 +238,8 @@ private[graphx] abstract class VertexPartitionBaseOps * because these methods return a `Self` and this implicit conversion re-wraps that in a * `VertexPartitionBaseOps`. This relies on the context bound on `Self`. */ - private implicit def toOps[VD2: ClassTag]( - partition: Self[VD2]): VertexPartitionBaseOps[VD2, Self] = { + private implicit def toOps[VD2: ClassTag](partition: Self[VD2]) + : VertexPartitionBaseOps[VD2, Self] = { implicitly[VertexPartitionBaseOpsConstructor[Self]].toOps(partition) } } From 012839807c3dc6e7c8c41ac6e956d52a550bb031 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 1 Jan 2015 15:03:54 -0800 Subject: [PATCH 046/215] [HOTFIX] Bind web UI to ephemeral port in DriverSuite The job launched by DriverSuite should bind the web UI to an ephemeral port, since it looks like port contention in this test has caused a large number of Jenkins failures when many builds are started simultaneously. Our tests already disable the web UI, but this doesn't affect subprocesses launched by our tests. 
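A minimal sketch of the ephemeral-port configuration this hotfix applies (the application name is an arbitrary placeholder); setting spark.ui.port to 0 asks the OS for any free port, so concurrently running test JVMs cannot collide:

```
import org.apache.spark.{SparkConf, SparkContext}

// Port 0 means "bind to any free port" for the subprocess's web UI.
val conf = new SparkConf().set("spark.ui.port", "0")
val sc = new SparkContext("local", "EphemeralUiPortExample", conf)
```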
In this case, I've opted to bind to an ephemeral port instead of disabling the UI because disabling features in this test may mask its ability to catch certain bugs. See also: e24d3a9 Author: Josh Rosen Closes #3873 from JoshRosen/driversuite-webui-port and squashes the following commits: 48cd05c [Josh Rosen] [HOTFIX] Bind web UI to ephemeral port in DriverSuite. --- core/src/test/scala/org/apache/spark/DriverSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 5265ba904032..541d8eac8055 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -50,7 +50,10 @@ class DriverSuite extends FunSuite with Timeouts { object DriverWithoutCleanup { def main(args: Array[String]) { Utils.configTestLog4j("INFO") - val sc = new SparkContext(args(0), "DriverWithoutCleanup") + // Bind the web UI to an ephemeral port in order to avoid conflicts with other tests running on + // the same machine (we shouldn't just disable the UI here, since that might mask bugs): + val conf = new SparkConf().set("spark.ui.port", "0") + val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() } } From bd88b7185358ae60efc83dc6cbb3fb1d2bff6074 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Fri, 2 Jan 2015 15:09:41 -0800 Subject: [PATCH 047/215] [SPARK-3325][Streaming] Add a parameter to the method print in class DStream This PR is a fixed version of the original PR #3237 by watermen and scwf. This adds the ability to specify how many elements to print in `DStream.print`. Author: Yadong Qi Author: q00251598 Author: Tathagata Das Author: wangfei Closes #3865 from tdas/print-num and squashes the following commits: cd34e9e [Tathagata Das] Fix bug 7c09f16 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into HEAD bb35d1a [Yadong Qi] Update MimaExcludes.scala f8098ca [Yadong Qi] Update MimaExcludes.scala f6ac3cb [Yadong Qi] Update MimaExcludes.scala e4ed897 [Yadong Qi] Update MimaExcludes.scala 3b9d5cf [wangfei] fix conflicts ec8a3af [q00251598] move to Spark 1.3 26a70c0 [q00251598] extend the Python DStream's print b589a4b [q00251598] add another print function --- project/MimaExcludes.scala | 3 +++ python/pyspark/streaming/dstream.py | 12 +++++++----- .../spark/streaming/api/java/JavaDStreamLike.scala | 10 +++++++++- .../apache/spark/streaming/dstream/DStream.scala | 14 +++++++++++--- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c377e5cffa7d..31d4c317ae56 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,6 +54,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") ) ++ Seq( + // SPARK-3325 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), // SPARK-2757 ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 0826ddc56e84..2fe39392ff08 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -157,18 +157,20 @@ def foreachRDD(self, func): api = self._ssc._jvm.PythonDStream api.callForeachRDD(self._jdstream, jfunc) - def pprint(self): + def pprint(self, num=10): """ - Print the first ten elements of each RDD generated in this DStream. + Print the first num elements of each RDD generated in this DStream. + + @param num: the number of elements from the first will be printed. """ def takeAndPrint(time, rdd): - taken = rdd.take(11) + taken = rdd.take(num + 1) print "-------------------------------------------" print "Time: %s" % time print "-------------------------------------------" - for record in taken[:10]: + for record in taken[:num]: print record - if len(taken) > 10: + if len(taken) > num: print "..." print diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 2a7004e56ef5..e0542eda1383 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -51,7 +51,15 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * operator, so this DStream will be registered as an output stream and there materialized. */ def print(): Unit = { - dstream.print() + print(10) + } + + /** + * Print the first num elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. + */ + def print(num: Int): Unit = { + dstream.print(num) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 7f8651e719d8..28fc00cf3944 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -605,13 +605,21 @@ abstract class DStream[T: ClassTag] ( * operator, so this DStream will be registered as an output stream and there materialized. */ def print() { + print(10) + } + + /** + * Print the first num elements of each RDD generated in this DStream. This is an output + * operator, so this DStream will be registered as an output stream and there materialized. 
+ */ + def print(num: Int) { def foreachFunc = (rdd: RDD[T], time: Time) => { - val first11 = rdd.take(11) + val firstNum = rdd.take(num + 1) println ("-------------------------------------------") println ("Time: " + time) println ("-------------------------------------------") - first11.take(10).foreach(println) - if (first11.size > 10) println("...") + firstNum.take(num).foreach(println) + if (firstNum.size > num) println("...") println() } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() From cdccc263b20c1bb27b864411c82cfad7daca1f47 Mon Sep 17 00:00:00 2001 From: Akhil Das Date: Fri, 2 Jan 2015 15:12:27 -0800 Subject: [PATCH 048/215] Fixed typos in streaming-kafka-integration.md Changed projrect to project :) Author: Akhil Das Closes #3876 from akhld/patch-1 and squashes the following commits: e0cf9ef [Akhil Das] Fixed typos in streaming-kafka-integration.md --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 1c956fcb40da..4378521dcac7 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -4,7 +4,7 @@ title: Spark Streaming + Kafka Integration Guide --- [Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. -1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). +1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). groupId = org.apache.spark artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} From 342612b65f3d77c660383a332f0346872f076647 Mon Sep 17 00:00:00 2001 From: sigmoidanalytics Date: Sat, 3 Jan 2015 19:46:08 -0800 Subject: [PATCH 049/215] [SPARK-5058] Updated broken links Updated the broken link pointing to the KafkaWordCount example to the correct one. Author: sigmoidanalytics Closes #3877 from sigmoidanalytics/patch-1 and squashes the following commits: 3e19b31 [sigmoidanalytics] Updated broken links --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 4378521dcac7..0e38fe2144e9 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -20,7 +20,7 @@ title: Spark Streaming + Kafka Integration Guide streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
    import org.apache.spark.streaming.kafka.*; From b96008d5529bac5fd57b76554fd01760139cffff Mon Sep 17 00:00:00 2001 From: Brennon York Date: Sun, 4 Jan 2015 12:40:39 -0800 Subject: [PATCH 050/215] [SPARK-794][Core] Remove sleep() in ClusterScheduler.stop Removed `sleep()` from the `stop()` method of the `TaskSchedulerImpl` class which, from the JIRA ticket, is believed to be a legacy artifact slowing down testing originally introduced in the `ClusterScheduler` class. Author: Brennon York Closes #3851 from brennonyork/SPARK-794 and squashes the following commits: 04c3e64 [Brennon York] Removed sleep() from the stop() method --- .../scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index cd3c015321e8..a41f3eef195d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -394,9 +394,6 @@ private[spark] class TaskSchedulerImpl( taskResultGetter.stop() } starvationTimer.cancel() - - // sleeping for an arbitrary 1 seconds to ensure that messages are sent out. - Thread.sleep(1000L) } override def defaultParallelism() = backend.defaultParallelism() From 3fddc9468fa50e7683caa973fec6c52e1132268d Mon Sep 17 00:00:00 2001 From: Dale Date: Sun, 4 Jan 2015 13:28:37 -0800 Subject: [PATCH 051/215] [SPARK-4787] Stop SparkContext if a DAGScheduler init error occurs Author: Dale Closes #3809 from tigerquoll/SPARK-4787 and squashes the following commits: 5661e01 [Dale] [SPARK-4787] Ensure that call to stop() doesn't lose the exception by using a finally block. 2172578 [Dale] [SPARK-4787] Stop context properly if an exception occurs during DAGScheduler initialization. 
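The pattern behind this fix, as a standalone sketch (the names here are illustrative, not the actual Spark code): run cleanup for partially initialized state inside a finally block so the rethrown exception always carries the original cause, even if the cleanup itself fails.

    // Sketch of the cleanup-without-losing-the-cause shape used in the patch below.
    def initWithCleanup(init: () => Unit, stop: () => Unit): Unit = {
      try {
        init()
      } catch {
        case e: Exception =>
          try {
            stop()  // release anything that was already started
          } finally {
            // rethrow with the original error attached, even if stop() itself threw
            throw new RuntimeException("Error while constructing DAGScheduler", e)
          }
      }
    }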
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index df1cb3cda2db..4c25d5d6c0ce 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -329,8 +329,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli try { dagScheduler = new DAGScheduler(this) } catch { - case e: Exception => throw - new SparkException("DAGScheduler cannot be initialized due to %s".format(e.getMessage)) + case e: Exception => { + try { + stop() + } finally { + throw new SparkException("Error while constructing DAGScheduler", e) + } + } } // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's From e767d7ddac5c2330af553f2a74b8575dfc7afb67 Mon Sep 17 00:00:00 2001 From: bilna Date: Sun, 4 Jan 2015 19:37:48 -0800 Subject: [PATCH 052/215] [SPARK-4631] unit test for MQTT Please review the unit test for MQTT Author: bilna Author: Bilna P Closes #3844 from Bilna/master and squashes the following commits: acea3a3 [bilna] Adding dependency with scope test 28681fa [bilna] Merge remote-tracking branch 'upstream/master' fac3904 [bilna] Correction in Indentation and coding style ed9db4c [bilna] Merge remote-tracking branch 'upstream/master' 4b34ee7 [Bilna P] Update MQTTStreamSuite.scala 04503cf [bilna] Added embedded broker service for mqtt test 89d804e [bilna] Merge remote-tracking branch 'upstream/master' fc8eb28 [bilna] Merge remote-tracking branch 'upstream/master' 4b58094 [Bilna P] Update MQTTStreamSuite.scala b1ac4ad [bilna] Added BeforeAndAfter 5f6bfd2 [bilna] Added BeforeAndAfter e8b6623 [Bilna P] Update MQTTStreamSuite.scala 5ca6691 [Bilna P] Update MQTTStreamSuite.scala 8616495 [bilna] [SPARK-4631] unit test for MQTT --- external/mqtt/pom.xml | 6 + .../streaming/mqtt/MQTTStreamSuite.scala | 110 +++++++++++++++--- 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 9025915f4447..d478267b605b 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -66,6 +66,12 @@ junit-interface test + + org.apache.activemq + activemq-core + 5.7.0 + test + target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 84595acf45cc..98fe6cb301f5 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,31 +17,111 @@ package org.apache.spark.streaming.mqtt -import org.scalatest.FunSuite +import java.net.{URI, ServerSocket} -import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.activemq.broker.{TransportConnector, BrokerService} +import org.apache.spark.util.Utils +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually +import scala.concurrent.duration._ +import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -class MQTTStreamSuite extends 
FunSuite { - - val batchDuration = Seconds(1) +class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { + private val batchDuration = Milliseconds(500) private val master: String = "local[2]" - private val framework: String = this.getClass.getSimpleName + private val freePort = findFreePort() + private val brokerUri = "//localhost:" + freePort + private val topic = "def" + private var ssc: StreamingContext = _ + private val persistenceDir = Utils.createTempDir() + private var broker: BrokerService = _ + private var connector: TransportConnector = _ - test("mqtt input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val brokerUrl = "abc" - val topic = "def" + before { + ssc = new StreamingContext(master, framework, batchDuration) + setupMQTT() + } - // tests the API, does not actually test data receiving - val test1: ReceiverInputDStream[String] = MQTTUtils.createStream(ssc, brokerUrl, topic) - val test2: ReceiverInputDStream[String] = - MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_AND_DISK_SER_2) + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + Utils.deleteRecursively(persistenceDir) + tearDownMQTT() + } - // TODO: Actually test receiving data + test("mqtt input stream") { + val sendMessage = "MQTT demo for spark streaming" + val receiveStream: ReceiverInputDStream[String] = + MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + var receiveMessage: List[String] = List() + receiveStream.foreachRDD { rdd => + if (rdd.collect.length > 0) { + receiveMessage = receiveMessage ::: List(rdd.first) + receiveMessage + } + } + ssc.start() + publishData(sendMessage) + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + assert(sendMessage.equals(receiveMessage(0))) + } ssc.stop() } + + private def setupMQTT() { + broker = new BrokerService() + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt:" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + private def tearDownMQTT() { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + } + + private def findFreePort(): Int = { + Utils.startServiceOnPort(23456, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + })._2 + } + + def publishData(data: String): Unit = { + var client: MqttClient = null + try { + val persistence: MqttClientPersistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic: MqttTopic = client.getTopic(topic) + val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) + message.setQos(1) + message.setRetained(true) + for (i <- 0 to 100) + msgTopic.publish(message) + } + } finally { + client.disconnect() + client.close() + client = null + } + } } From 939ba1f8f6e32fef9026cc43fce55b36e4b9bfd1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 4 Jan 2015 20:26:18 -0800 Subject: [PATCH 053/215] [SPARK-4835] Disable validateOutputSpecs for Spark Streaming jobs This patch disables output spec. validation for jobs launched through Spark Streaming, since this interferes with checkpoint recovery. 
Hadoop OutputFormats have a `checkOutputSpecs` method which performs certain checks prior to writing output, such as checking whether the output directory already exists. SPARK-1100 added checks for FileOutputFormat, SPARK-1677 (#947) added a SparkConf configuration to disable these checks, and SPARK-2309 (#1088) extended these checks to run for all OutputFormats, not just FileOutputFormat. In Spark Streaming, we might have to re-process a batch during checkpoint recovery, so `save` actions may be called multiple times. In addition to `DStream`'s own save actions, users might use `transform` or `foreachRDD` and call the `RDD` and `PairRDD` save actions. When output spec. validation is enabled, the second calls to these actions will fail due to existing output. This patch automatically disables output spec. validation for jobs submitted by the Spark Streaming scheduler. This is done by using Scala's `DynamicVariable` to propagate the bypass setting without having to mutate SparkConf or introduce a global variable. Author: Josh Rosen Closes #3832 from JoshRosen/SPARK-4835 and squashes the following commits: 36eaf35 [Josh Rosen] Add comment explaining use of transform() in test. 6485cf8 [Josh Rosen] Add test case in Streaming; fix bug for transform() 7b3e06a [Josh Rosen] Remove Streaming-specific setting to undo this change; update conf. guide bf9094d [Josh Rosen] Revise disableOutputSpecValidation() comment to not refer to Spark Streaming. e581d17 [Josh Rosen] Deduplicate isOutputSpecValidationEnabled logic. 762e473 [Josh Rosen] [SPARK-4835] Disable validateOutputSpecs for Spark Streaming jobs. --- .../apache/spark/rdd/PairRDDFunctions.scala | 19 ++++++++- docs/configuration.md | 4 +- .../spark/streaming/dstream/DStream.scala | 10 ++++- .../dstream/TransformedDStream.scala | 2 +- .../streaming/scheduler/JobScheduler.scala | 8 +++- .../spark/streaming/CheckpointSuite.scala | 39 +++++++++++++++++++ 6 files changed, 75 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 4469c89e6bb1..f8df5b2a0886 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -25,6 +25,7 @@ import scala.collection.{Map, mutable} import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import scala.util.DynamicVariable import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} @@ -964,7 +965,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance - if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { + if (isOutputSpecValidationEnabled) { // FileOutputFormat ignores the filesystem parameter jobFormat.checkOutputSpecs(job) } @@ -1042,7 +1043,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) { + if (isOutputSpecValidationEnabled) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(hadoopConf) hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) @@ -1117,8 +1118,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def valueClass: Class[_] = 
vt.runtimeClass private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + private def isOutputSpecValidationEnabled: Boolean = { + val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value + val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } } private[spark] object PairRDDFunctions { val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. + */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) } diff --git a/docs/configuration.md b/docs/configuration.md index fa9d311f8506..9bb649999373 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -709,7 +709,9 @@ Apart from these, the following properties are also available, and may be useful If set to true, validates the output specification (e.g. checking if the output directory already exists) used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing output directories. We recommend that users do not disable this except if trying to achieve compatibility with - previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since + data may need to be rewritten to pre-existing output directories during checkpoint recovery. spark.hadoop.cloneConf diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 28fc00cf3944..b874f561c12e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import scala.util.matching.Regex import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.{BlockRDD, RDD} +import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName @@ -292,7 +292,13 @@ abstract class DStream[T: ClassTag] ( // set this DStream's creation site, generate RDDs and then restore the previous call site. val prevCallSite = ssc.sparkContext.getCallSite() ssc.sparkContext.setCallSite(creationSite) - val rddOption = compute(time) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. We need to have this call here because + // compute() might cause Spark jobs to be launched. 
+ val rddOption = PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + compute(time) + } ssc.sparkContext.setCallSite(prevCallSite) rddOption.foreach { case newRDD => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 7cd4554282ca..71b61856e23c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming.{Duration, Time} import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index cfa3cd8925c8..0e0f5bd3b9db 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} import akka.actor.{ActorRef, Actor, Props} import org.apache.spark.{SparkException, Logging, SparkEnv} +import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ @@ -168,7 +169,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private class JobHandler(job: Job) extends Runnable { def run() { eventActor ! JobStarted(job) - job.run() + // Disable checks for existing output directories in jobs launched by the streaming scheduler, + // since we may need to write output to an existing directory during checkpoint recovery; + // see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } eventActor ! JobCompleted(job) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 72d055eb2ea3..5d232c6ade7a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -255,6 +255,45 @@ class CheckpointSuite extends TestSuiteBase { } } + test("recovery with saveAsHadoopFile inside transform operation") { + // Regression test for SPARK-4835. + // + // In that issue, the problem was that `saveAsHadoopFile(s)` would fail when the last batch + // was restarted from a checkpoint since the output directory would already exist. However, + // the other saveAsHadoopFile* tests couldn't catch this because they only tested whether the + // output matched correctly and not whether the post-restart batch had successfully finished + // without throwing any errors. The following test reproduces the same bug with a test that + // actually fails because the error in saveAsHadoopFile causes transform() to fail, which + // prevents the expected output from being written to the output stream. + // + // This is not actually a valid use of transform, but it's being used here so that we can test + // the fix for SPARK-4835 independently of additional test cleanup. 
+ // + // After SPARK-5079 is addressed, should be able to remove this test since a strengthened + // version of the other saveAsHadoopFile* tests would prevent regressions for this issue. + val tempDir = Files.createTempDir() + try { + testCheckpointedOperation( + Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()), + (s: DStream[String]) => { + s.transform { (rdd, time) => + val output = rdd.map(x => (x, 1)).reduceByKey(_ + _) + output.saveAsHadoopFile( + new File(tempDir, "result-" + time.milliseconds).getAbsolutePath, + classOf[Text], + classOf[IntWritable], + classOf[TextOutputFormat[Text, IntWritable]]) + output + } + }, + Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()), + 3 + ) + } finally { + Utils.deleteRecursively(tempDir) + } + } + // This tests whether the StateDStream's RDD checkpoints works correctly such // that the system can recover from a master failure. This assumes as reliable, // replayable input source - TestInputDStream. From 72396522bcf5303f761956658510672e4feb2845 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 21:03:17 -0800 Subject: [PATCH 054/215] [SPARK-5067][Core] Use '===' to compare well-defined case class A simple fix would be adding `assert(e1.appId == e2.appId)` for `SparkListenerApplicationStart`. But actually we can use `===` for well-defined case class directly. Therefore, instead of fixing this issue, I use `===` to compare those well-defined case classes (all fields have implemented a correct `equals` method, such as primitive types) Author: zsxwing Closes #3886 from zsxwing/SPARK-5067 and squashes the following commits: 0a51711 [zsxwing] Use '===' to compare well-defined case class --- .../apache/spark/util/JsonProtocolSuite.scala | 32 +++---------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 593d6dd8c379..63c2559c5c5f 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -280,7 +280,7 @@ class JsonProtocolSuite extends FunSuite { private def testBlockManagerId(id: BlockManagerId) { val newId = JsonProtocol.blockManagerIdFromJson(JsonProtocol.blockManagerIdToJson(id)) - assertEquals(id, newId) + assert(id === newId) } private def testTaskInfo(info: TaskInfo) { @@ -335,22 +335,8 @@ class JsonProtocolSuite extends FunSuite { assertEquals(e1.jobResult, e2.jobResult) case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) - case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) => - assert(e1.maxMem === e2.maxMem) - assert(e1.time === e2.time) - assertEquals(e1.blockManagerId, e2.blockManagerId) - case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) => - assert(e1.time === e2.time) - assertEquals(e1.blockManagerId, e2.blockManagerId) - case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) => - assert(e1.rddId == e2.rddId) - case (e1: SparkListenerApplicationStart, e2: SparkListenerApplicationStart) => - assert(e1.appName == e2.appName) - assert(e1.time == e2.time) - assert(e1.sparkUser == e2.sparkUser) - case (e1: SparkListenerApplicationEnd, e2: SparkListenerApplicationEnd) => - assert(e1.time == e2.time) - case (SparkListenerShutdown, 
SparkListenerShutdown) => + case (e1, e2) => + assert(e1 === e2) case _ => fail("Events don't match in types!") } } @@ -435,16 +421,6 @@ class JsonProtocolSuite extends FunSuite { assert(metrics1.bytesRead === metrics2.bytesRead) } - private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) { - if (bm1 == null || bm2 == null) { - assert(bm1 === bm2) - } else { - assert(bm1.executorId === bm2.executorId) - assert(bm1.host === bm2.host) - assert(bm1.port === bm2.port) - } - } - private def assertEquals(result1: JobResult, result2: JobResult) { (result1, result2) match { case (JobSucceeded, JobSucceeded) => @@ -462,7 +438,7 @@ class JsonProtocolSuite extends FunSuite { assert(r1.shuffleId === r2.shuffleId) assert(r1.mapId === r2.mapId) assert(r1.reduceId === r2.reduceId) - assertEquals(r1.bmAddress, r2.bmAddress) + assert(r1.bmAddress === r2.bmAddress) assert(r1.message === r2.message) case (r1: ExceptionFailure, r2: ExceptionFailure) => assert(r1.className === r2.className) From 6c726a3fbd9cd3aa5f3a1992b2132b25eabb76a0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 21:06:04 -0800 Subject: [PATCH 055/215] [SPARK-5069][Core] Fix the race condition of TaskSchedulerImpl.dagScheduler It's not necessary to set `TaskSchedulerImpl.dagScheduler` in preStart. It's safe to set it after `initializeEventProcessActor()`. Author: zsxwing Closes #3887 from zsxwing/SPARK-5069 and squashes the following commits: d95894f [zsxwing] Fix the race condition of TaskSchedulerImpl.dagScheduler --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 7 +------ .../apache/spark/scheduler/TaskSchedulerImplSuite.scala | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index cb8ccfbdbdcb..259621d263d7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -138,6 +138,7 @@ class DAGScheduler( } initializeEventProcessActor() + taskScheduler.setDAGScheduler(this) // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { @@ -1375,12 +1376,6 @@ private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler) extends Actor with Logging { - override def preStart() { - // set DAGScheduler for taskScheduler to ensure eventProcessActor is always - // valid when the messages arrive - dagScheduler.taskScheduler.setDAGScheduler(dagScheduler) - } - /** * The main event loop of the DAG scheduler. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 40aaf9dd1f1e..00812e6018d1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -305,7 +305,6 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} override def executorAdded(execId: String, host: String) {} } - taskScheduler.setDAGScheduler(dagScheduler) // Give zero core offers. 
Should not generate any tasks val zeroCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", 0), new WorkerOffer("executor1", "host1", 0)) From 27e7f5a7237d9d64a3b2c8a030ba3e3a9a96b26c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 21:09:21 -0800 Subject: [PATCH 056/215] [SPARK-5083][Core] Fix a flaky test in TaskResultGetterSuite Because `sparkEnv.blockManager.master.removeBlock` is asynchronous, we need to make sure the block has already been removed before calling `super.enqueueSuccessfulTask`. Author: zsxwing Closes #3894 from zsxwing/SPARK-5083 and squashes the following commits: d97c03d [zsxwing] Fix a flaky test in TaskResultGetterSuite --- .../scheduler/TaskResultGetterSuite.scala | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 3aab5a156ee7..e3a3803e6483 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -19,7 +19,12 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.control.NonFatal + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} import org.apache.spark.storage.TaskResultBlockId @@ -34,6 +39,8 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule extends TaskResultGetter(sparkEnv, scheduler) { var removedResult = false + @volatile var removeBlockSuccessfully = false + override def enqueueSuccessfulTask( taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) { if (!removedResult) { @@ -42,6 +49,15 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule serializer.get().deserialize[TaskResult[_]](serializedData) match { case IndirectTaskResult(blockId, size) => sparkEnv.blockManager.master.removeBlock(blockId) + // removeBlock is asynchronous. Need to wait it's removed successfully + try { + eventually(timeout(3 seconds), interval(200 milliseconds)) { + assert(!sparkEnv.blockManager.master.contains(blockId)) + } + removeBlockSuccessfully = true + } catch { + case NonFatal(e) => removeBlockSuccessfully = false + } case directResult: DirectTaskResult[_] => taskSetManager.abort("Internal error: expect only indirect results") } @@ -92,10 +108,12 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSpark assert(false, "Expect local cluster to use TaskSchedulerImpl") throw new ClassCastException } - scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) + val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) + scheduler.taskResultGetter = resultGetter val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) + assert(resultGetter.removeBlockSuccessfully) assert(result === 1.to(akkaFrameSize).toArray) // Make sure two tasks were run (one failed one, and a second retried one). 
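For context on the polling used in the fix above (a sketch, not part of the patch): ScalaTest's Eventually retries a block of assertions until it passes or the timeout expires, which is what lets the test tolerate the asynchronous removeBlock call. Inside a suite that can use Eventually, the shape is roughly:

    import scala.concurrent.duration._
    import org.scalatest.concurrent.Eventually._

    def condition: Boolean = true  // placeholder for the asynchronous state being polled

    // Retried every 200 ms until it holds, failing only if 3 seconds elapse first.
    eventually(timeout(3 seconds), interval(200 milliseconds)) {
      assert(condition, "still waiting for the asynchronous operation to complete")
    }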
From 5c506cecb933b156b2f06a688ee08c4347bf0d47 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 4 Jan 2015 21:18:33 -0800 Subject: [PATCH 057/215] [SPARK-5074][Core] Fix a non-deterministic test failure Add `assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))` to make sure `sparkListener` receive the message. Author: zsxwing Closes #3889 from zsxwing/SPARK-5074 and squashes the following commits: e61c198 [zsxwing] Fix a non-deterministic test failure --- .../scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d6ec9e129cce..d30eb10bbe94 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -247,6 +247,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F test("[SPARK-3353] parent stage should have lower stage id") { sparkListener.stageByOrderOfExecution.clear() sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.stageByOrderOfExecution.length === 2) assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) } From d3f07fd23cc26a70f44c52e24445974d4885d58a Mon Sep 17 00:00:00 2001 From: Varun Saxena Date: Mon, 5 Jan 2015 10:32:37 -0800 Subject: [PATCH 058/215] [SPARK-4688] Have a single shared network timeout in Spark [SPARK-4688] Have a single shared network timeout in Spark Author: Varun Saxena Author: varunsaxena Closes #3562 from varunsaxena/SPARK-4688 and squashes the following commits: 6e97f72 [Varun Saxena] [SPARK-4688] Single shared network timeout cd783a2 [Varun Saxena] SPARK-4688 d6f8c29 [Varun Saxena] SCALA-4688 9562b15 [Varun Saxena] SPARK-4688 a75f014 [varunsaxena] SPARK-4688 594226c [varunsaxena] SPARK-4688 --- .../apache/spark/network/nio/ConnectionManager.scala | 3 ++- .../apache/spark/storage/BlockManagerMasterActor.scala | 7 +++++-- .../main/scala/org/apache/spark/util/AkkaUtils.scala | 2 +- docs/configuration.md | 10 ++++++++++ .../org/apache/spark/network/util/TransportConf.java | 4 +++- 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 243b71c98086..98455c096826 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -81,7 +81,8 @@ private[nio] class ConnectionManager( private val ackTimeoutMonitor = new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) - private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) + private val ackTimeout = + conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 100)) // Get the thread counts from the Spark Configuration. 
// diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 9cbda41223a8..9d77cf27882e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,8 +52,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", - math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) + val slaveTimeout = { + val defaultMs = math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000) + val networkTimeout = conf.getInt("spark.network.timeout", defaultMs / 1000) + conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", networkTimeout * 1000) + } val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 8c2457f56bff..64e3a5416c6b 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -65,7 +65,7 @@ private[spark] object AkkaUtils extends Logging { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.getInt("spark.akka.timeout", 100) + val akkaTimeout = conf.getInt("spark.akka.timeout", conf.getInt("spark.network.timeout", 100)) val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" diff --git a/docs/configuration.md b/docs/configuration.md index 9bb649999373..7ada67fc303c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -818,6 +818,16 @@ Apart from these, the following properties are also available, and may be useful Communication timeout between Spark nodes, in seconds. + + spark.network.timeout + 100 + + Default timeout for all network interactions, in seconds. This config will be used in + place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, + spark.storage.blockManagerSlaveTimeoutMs or spark.shuffle.io.connectionTimeout, + if they are not configured. + + spark.akka.heartbeat.pauses 6000 diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 7c9adf52af0f..e34382da22a5 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -37,7 +37,9 @@ public boolean preferDirectBufs() { /** Connect timeout in milliseconds. Default 120 secs. */ public int connectionTimeoutMs() { - return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; + int timeout = + conf.getInt("spark.shuffle.io.connectionTimeout", conf.getInt("spark.network.timeout", 100)); + return timeout * 1000; } /** Number of concurrent connections between two nodes for fetching data. 
*/ From ce39b34404868de4ca51be06832169187b1aef7d Mon Sep 17 00:00:00 2001 From: WangTao Date: Mon, 5 Jan 2015 11:59:38 -0800 Subject: [PATCH 059/215] [SPARK-5057] Log message in failed askWithReply attempts https://issues.apache.org/jira/browse/SPARK-5057 Author: WangTao Author: WangTaoTheTonic Closes #3875 from WangTaoTheTonic/SPARK-5057 and squashes the following commits: 1503487 [WangTao] use string interpolation 706c8a7 [WangTaoTheTonic] log more messages --- .../scala/org/apache/spark/util/AkkaUtils.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 64e3a5416c6b..8d86fd3e11ad 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -89,7 +89,7 @@ private[spark] object AkkaUtils extends Logging { } val requireCookie = if (isAuthOn) "on" else "off" val secureCookie = if (isAuthOn) secretKey else "" - logDebug("In createActorSystem, requireCookie is: " + requireCookie) + logDebug(s"In createActorSystem, requireCookie is: $requireCookie") val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( @@ -140,8 +140,8 @@ private[spark] object AkkaUtils extends Logging { def maxFrameSizeBytes(conf: SparkConf): Int = { val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10) if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { - throw new IllegalArgumentException("spark.akka.frameSize should not be greater than " - + AKKA_MAX_FRAME_SIZE_IN_MB + "MB") + throw new IllegalArgumentException( + s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") } frameSizeInMB * 1024 * 1024 } @@ -182,8 +182,8 @@ private[spark] object AkkaUtils extends Logging { timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { - throw new SparkException("Error sending message as actor is null " + - "[message = " + message + "]") + throw new SparkException(s"Error sending message [message = $message]" + + " as actor is null ") } var attempts = 0 var lastException: Exception = null @@ -200,13 +200,13 @@ private[spark] object AkkaUtils extends Logging { case ie: InterruptedException => throw ie case e: Exception => lastException = e - logWarning("Error sending message in " + attempts + " attempts", e) + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) } Thread.sleep(retryInterval) } throw new SparkException( - "Error sending message [message = " + message + "]", lastException) + s"Error sending message [message = $message]", lastException) } def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { From 1c0e7ce056c79e1db96f85b8c56a479b8b043970 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Mon, 5 Jan 2015 12:05:09 -0800 Subject: [PATCH 060/215] [SPARK-4465] runAsSparkUser doesn't affect TaskRunner in Mesos environme... ...nt at all. - fixed a scope of runAsSparkUser from MesosExecutorDriver.run to MesosExecutorBackend.launchTask - See the Jira Issue for more details. Author: Jongyoul Lee Closes #3741 from jongyoul/SPARK-4465 and squashes the following commits: 46ad71e [Jongyoul Lee] [SPARK-4465] runAsSparkUser doesn't affect TaskRunner in Mesos environment at all. - Removed unused import 3d6631f [Jongyoul Lee] [SPARK-4465] runAsSparkUser doesn't affect TaskRunner in Mesos environment at all. 
- Removed comments and adjusted indentations 2343f13 [Jongyoul Lee] [SPARK-4465] runAsSparkUser doesn't affect TaskRunner in Mesos environment at all. - fixed a scope of runAsSparkUser from MesosExecutorDriver.run to MesosExecutorBackend.launchTask --- .../spark/executor/MesosExecutorBackend.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index a098d07bd865..2e23ae0a4f83 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions._ import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary} +import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} @@ -80,7 +80,9 @@ private[spark] class MesosExecutorBackend if (executor == null) { logError("Received launchTask but executor was null") } else { - executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer) + SparkHadoopUtil.get.runAsSparkUser { () => + executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer) + } } } @@ -112,11 +114,8 @@ private[spark] class MesosExecutorBackend private[spark] object MesosExecutorBackend extends Logging { def main(args: Array[String]) { SignalLogger.register(log) - SparkHadoopUtil.get.runAsSparkUser { () => - MesosNativeLibrary.load() - // Create a new Executor and start it running - val runner = new MesosExecutorBackend() - new MesosExecutorDriver(runner).run() - } + // Create a new Executor and start it running + val runner = new MesosExecutorBackend() + new MesosExecutorDriver(runner).run() } } From 6c6f32574023b8e43a24f2081ff17e6e446de2f3 Mon Sep 17 00:00:00 2001 From: freeman Date: Mon, 5 Jan 2015 13:10:59 -0800 Subject: [PATCH 061/215] [SPARK-5089][PYSPARK][MLLIB] Fix vector convert This is a small change addressing a potentially significant bug in how PySpark + MLlib handles non-float64 numpy arrays. The automatic conversion to `DenseVector` that occurs when passing RDDs to MLlib algorithms in PySpark should automatically upcast to float64s, but currently this wasn't actually happening. As a result, non-float64 would be silently parsed inappropriately during SerDe, yielding erroneous results when running, for example, KMeans. The PR includes the fix, as well as a new test for the correct conversion behavior. 
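The root cause, illustrated with a small standalone numpy sketch (not Spark code): numpy's astype returns a converted copy rather than changing the array in place, so the result has to be assigned back for the upcast to take effect.

    import numpy as np

    ar = np.array([1, 2, 3, 4], dtype=np.float32)
    ar.astype(np.float64)        # returns a converted copy; `ar` itself is unchanged
    print(ar.dtype)              # float32

    ar = ar.astype(np.float64)   # the fix: keep the converted copy
    print(ar.dtype)              # float64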
davies Author: freeman Closes #3902 from freeman-lab/fix-vector-convert and squashes the following commits: 764db47 [freeman] Add a test for proper conversion behavior 704f97e [freeman] Return array after changing type --- python/pyspark/mllib/linalg.py | 2 +- python/pyspark/mllib/tests.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f7aa2b0cb04b..4f8491f43e45 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -178,7 +178,7 @@ def __init__(self, ar): elif not isinstance(ar, np.ndarray): ar = np.array(ar, dtype=np.float64) if ar.dtype != np.float64: - ar.astype(np.float64) + ar = ar.astype(np.float64) self.array = ar def __reduce__(self): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5034f229e824..1f48bc1219db 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -110,6 +110,16 @@ def test_squared_distance(self): self.assertEquals(0.0, _squared_distance(dv, dv)) self.assertEquals(0.0, _squared_distance(lst, lst)) + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + class ListTests(PySparkTestCase): From bbcba3a9430365640c0188e7ca6e0677d3227dd8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 5 Jan 2015 15:19:53 -0800 Subject: [PATCH 062/215] [SPARK-5093] Set spark.network.timeout to 120s consistently. Author: Reynold Xin Closes #3903 from rxin/timeout-120 and squashes the following commits: 7c2138e [Reynold Xin] [SPARK-5093] Set spark.network.timeout to 120s consistently. --- .../org/apache/spark/network/nio/ConnectionManager.scala | 2 +- .../org/apache/spark/storage/BlockManagerMasterActor.scala | 6 +----- core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 2 +- docs/configuration.md | 6 +++--- .../java/org/apache/spark/network/util/TransportConf.java | 5 ++--- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 98455c096826..3340fca08014 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -82,7 +82,7 @@ private[nio] class ConnectionManager( new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = - conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 100)) + conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120)) // Get the thread counts from the Spark Configuration. 
// diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 9d77cf27882e..64133464d8da 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -52,11 +52,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = { - val defaultMs = math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000) - val networkTimeout = conf.getInt("spark.network.timeout", defaultMs / 1000) - conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", networkTimeout * 1000) - } + val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000) val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 8d86fd3e11ad..db2531dc171f 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -65,7 +65,7 @@ private[spark] object AkkaUtils extends Logging { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.getInt("spark.akka.timeout", conf.getInt("spark.network.timeout", 100)) + val akkaTimeout = conf.getInt("spark.akka.timeout", conf.getInt("spark.network.timeout", 120)) val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" diff --git a/docs/configuration.md b/docs/configuration.md index 7ada67fc303c..2add48569bec 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -820,12 +820,12 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout - 100 + 120 Default timeout for all network interactions, in seconds. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, - spark.storage.blockManagerSlaveTimeoutMs or spark.shuffle.io.connectionTimeout, - if they are not configured. + spark.storage.blockManagerSlaveTimeoutMs or + spark.shuffle.io.connectionTimeout, if they are not configured. diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index e34382da22a5..6c9178688693 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -37,9 +37,8 @@ public boolean preferDirectBufs() { /** Connect timeout in milliseconds. Default 120 secs. */ public int connectionTimeoutMs() { - int timeout = - conf.getInt("spark.shuffle.io.connectionTimeout", conf.getInt("spark.network.timeout", 100)); - return timeout * 1000; + int defaultTimeout = conf.getInt("spark.network.timeout", 120); + return conf.getInt("spark.shuffle.io.connectionTimeout", defaultTimeout) * 1000; } /** Number of concurrent connections between two nodes for fetching data. 
*/ From 04d55d8e8e4890d110ce5561b5c1ae608c34a7c9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 5 Jan 2015 15:34:22 -0800 Subject: [PATCH 063/215] [SPARK-5040][SQL] Support expressing unresolved attributes using $"attribute name" notation in SQL DSL. Author: Reynold Xin Closes #3862 from rxin/stringcontext-attr and squashes the following commits: 9b10f57 [Reynold Xin] Rename StrongToAttributeConversionHelper 72121af [Reynold Xin] [SPARK-5040][SQL] Support expressing unresolved attributes using $"attribute name" notation in SQL DSL. --- .../org/apache/spark/sql/catalyst/dsl/package.scala | 9 +++++++++ .../scala/org/apache/spark/sql/DslQuerySuite.scala | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8e39f79d2ca5..9608e15c0f30 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -135,6 +135,15 @@ package object dsl { implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute = analysis.UnresolvedAttribute(s.name) + /** Converts $"col name" into an [[analysis.UnresolvedAttribute]]. */ + implicit class StringToAttributeConversionHelper(val sc: StringContext) { + // Note that if we make ExpressionConversions an object rather than a trait, we can + // then make this a value class to avoid the small penalty of runtime instantiation. + def $(args: Any*): analysis.UnresolvedAttribute = { + analysis.UnresolvedAttribute(sc.s(args :_*)) + } + } + def sum(e: Expression) = Sum(e) def sumDistinct(e: Expression) = SumDistinct(e) def count(e: Expression) = Count(e) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index c0b9cf516312..ab88f3ad10d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -56,6 +56,18 @@ class DslQuerySuite extends QueryTest { ) } + test("convert $\"attribute name\" into unresolved attribute") { + checkAnswer( + testData.where($"key" === 1).select($"value"), + Seq(Seq("1"))) + } + + test("convert Scala Symbol 'attrname into unresolved attribute") { + checkAnswer( + testData.where('key === 1).select('value), + Seq(Seq("1"))) + } + test("select *") { checkAnswer( testData.select(Star(None)), From 451546aa6d2e61e43b0c0f0669f18cfb7489e584 Mon Sep 17 00:00:00 2001 From: Kostas Sakellis Date: Mon, 5 Jan 2015 23:26:33 -0800 Subject: [PATCH 064/215] SPARK-4843 [YARN] Squash ExecutorRunnableUtil and ExecutorRunnable ExecutorRunnableUtil is a parent of ExecutorRunnable because of the yarn-alpha and yarn-stable split. Now that yarn-alpha is gone, this commit squashes the unnecessary hierarchy. The methods from ExecutorRunnableUtil are added as private. 
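The shape of this refactoring, as a minimal standalone sketch (the names are invented, not the real Spark classes): a helper trait that existed only to be shared across the yarn-alpha and yarn-stable builds is folded into its single remaining subclass, with its members becoming private.

    // Before: shared parent needed only because two YARN code paths once existed.
    trait RunnableHelpers {
      def prepareEnvironment(): Map[String, String] = Map("EXAMPLE_VAR" -> "value")
    }
    class ContainerRunnable extends RunnableHelpers

    // After: one code path left, so the helper is inlined and hidden.
    class FlattenedContainerRunnable {
      private def prepareEnvironment(): Map[String, String] = Map("EXAMPLE_VAR" -> "value")
      lazy val env = prepareEnvironment()
    }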
Author: Kostas Sakellis Closes #3696 from ksakellis/kostas-spark-4843 and squashes the following commits: 486716f [Kostas Sakellis] Moved prepareEnvironment call to after yarnConf declaration 470e22e [Kostas Sakellis] Fixed indentation and renamed sparkConf variable 9b1b1c9 [Kostas Sakellis] SPARK-4843 [YARN] Squash ExecutorRunnableUtil and ExecutorRunnable --- .../spark/deploy/yarn/ExecutorRunnable.scala | 182 +++++++++++++++- .../deploy/yarn/ExecutorRunnableUtil.scala | 203 ------------------ 2 files changed, 172 insertions(+), 213 deletions(-) delete mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index fdd3c2300fa7..6d9198c122e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -17,32 +17,33 @@ package org.apache.spark.deploy.yarn +import java.net.URI import java.nio.ByteBuffer -import java.security.PrivilegedExceptionAction + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.spark.util.Utils import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer -import org.apache.hadoop.net.NetUtils import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils -import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.client.api.NMClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.network.util.JavaUtils - class ExecutorRunnable( container: Container, conf: Configuration, - spConf: SparkConf, + sparkConf: SparkConf, masterAddress: String, slaveId: String, hostname: String, @@ -50,13 +51,13 @@ class ExecutorRunnable( executorCores: Int, appId: String, securityMgr: SecurityManager) - extends Runnable with ExecutorRunnableUtil with Logging { + extends Runnable with Logging { var rpc: YarnRPC = YarnRPC.create(conf) var nmClient: NMClient = _ - val sparkConf = spConf val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - + lazy val env = prepareEnvironment + def run = { logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() @@ -110,4 +111,165 @@ class ExecutorRunnable( nmClient.startContainer(container, ctx) } + private def prepareCommand( + masterAddress: String, + slaveId: String, + hostname: String, + executorMemory: Int, + executorCores: Int, + appId: String, + localResources: HashMap[String, LocalResource]): List[String] = { + // Extra options for the JVM + val javaOpts = ListBuffer[String]() + + // Set the environment variable through a command prefix + // to append to the existing value of the variable + var prefixEnv: Option[String] = None + + // Set the JVM memory + val executorMemoryString = executorMemory + "m" + javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " + + // Set extra Java options for the executor, 
if defined + sys.props.get("spark.executor.extraJavaOptions").foreach { opts => + javaOpts += opts + } + sys.env.get("SPARK_JAVA_OPTS").foreach { opts => + javaOpts += opts + } + sys.props.get("spark.executor.extraLibraryPath").foreach { p => + prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + } + + javaOpts += "-Djava.io.tmpdir=" + + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + + // Certain configs need to be passed here because they are needed before the Executor + // registers with the Scheduler and transfers the spark configs. Since the Executor backend + // uses Akka to connect to the scheduler, the akka settings are needed as well as the + // authentication settings. + sparkConf.getAll. + filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } + + sparkConf.getAkkaConf. + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } + + // Commenting it out for now - so that people can refer to the properties if required. Remove + // it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence + // if there are multiple containers in same node, spark gc effects all other containers + // performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in + // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset + // of cores on a node. + /* + else { + // If no java_opts specified, default to using -XX:+CMSIncrementalMode + // It might be possible that other modes/config is being done in + // spark.executor.extraJavaOptions, so we dont want to mess with it. + // In our expts, using (default) throughput collector has severe perf ramnifications in + // multi-tennent machines + // The options are based on + // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use + // %20the%20Concurrent%20Low%20Pause%20Collector|outline + javaOpts += " -XX:+UseConcMarkSweepGC " + javaOpts += " -XX:+CMSIncrementalMode " + javaOpts += " -XX:+CMSIncrementalPacing " + javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " + javaOpts += " -XX:CMSIncrementalDutyCycle=10 " + } + */ + + // For log4j configuration to reference + javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + + val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", + "-server", + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. + // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in + // an inconsistent state. + // TODO: If the OOM is not recoverable by rescheduling it on different node, then do + // 'something' to fail job ... akin to blacklisting trackers in mapred ? 
+ "-XX:OnOutOfMemoryError='kill %p'") ++ + javaOpts ++ + Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", + masterAddress.toString, + slaveId.toString, + hostname.toString, + executorCores.toString, + appId, + "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", + "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + + // TODO: it would be nicer to just make sure there are no null commands here + commands.map(s => if (s == null) "null" else s).toList + } + + private def setupDistributedCache( + file: String, + rtype: LocalResourceType, + localResources: HashMap[String, LocalResource], + timestamp: String, + size: String, + vis: String): Unit = { + val uri = new URI(file) + val amJarRsrc = Records.newRecord(classOf[LocalResource]) + amJarRsrc.setType(rtype) + amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) + amJarRsrc.setTimestamp(timestamp.toLong) + amJarRsrc.setSize(size.toLong) + localResources(uri.getFragment()) = amJarRsrc + } + + private def prepareLocalResources: HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val localResources = HashMap[String, LocalResource]() + + if (System.getenv("SPARK_YARN_CACHE_FILES") != null) { + val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') + val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') + val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',') + for( i <- 0 to distFiles.length - 1) { + setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i), + fileSizes(i), visibilities(i)) + } + } + + if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) { + val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',') + val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',') + val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',') + for( i <- 0 to distArchives.length - 1) { + setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources, + timeStamps(i), fileSizes(i), visibilities(i)) + } + } + + logInfo("Prepared Local resources " + localResources) + localResources + } + + private def prepareEnvironment: HashMap[String, String] = { + val env = new HashMap[String, String]() + val extraCp = sparkConf.getOption("spark.executor.extraClassPath") + ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) + + sparkConf.getExecutorEnv.foreach { case (key, value) => + // This assumes each executor environment variable set here is a path + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } + + // Keep this for backwards compatibility but users should move to the config + sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => + YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) + } + + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } + env + } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala deleted file mode 100644 index 22d73ecf6d01..000000000000 --- 
a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn - -import java.net.URI - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, ListBuffer} - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.{ConverterUtils, Records} - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.util.Utils - -trait ExecutorRunnableUtil extends Logging { - - val yarnConf: YarnConfiguration - val sparkConf: SparkConf - lazy val env = prepareEnvironment - - def prepareCommand( - masterAddress: String, - slaveId: String, - hostname: String, - executorMemory: Int, - executorCores: Int, - appId: String, - localResources: HashMap[String, LocalResource]): List[String] = { - // Extra options for the JVM - val javaOpts = ListBuffer[String]() - - // Set the environment variable through a command prefix - // to append to the existing value of the variable - var prefixEnv: Option[String] = None - - // Set the JVM memory - val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " - - // Set extra Java options for the executor, if defined - sys.props.get("spark.executor.extraJavaOptions").foreach { opts => - javaOpts += opts - } - sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts += opts - } - sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) - } - - javaOpts += "-Djava.io.tmpdir=" + - new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) - - // Certain configs need to be passed here because they are needed before the Executor - // registers with the Scheduler and transfers the spark configs. Since the Executor backend - // uses Akka to connect to the scheduler, the akka settings are needed as well as the - // authentication settings. - sparkConf.getAll. - filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } - - sparkConf.getAkkaConf. - foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } - - // Commenting it out for now - so that people can refer to the properties if required. Remove - // it once cpuset version is pushed out. 
- // The context is, default gc for server class machines end up using all cores to do gc - hence - // if there are multiple containers in same node, spark gc effects all other containers - // performance (which can also be other spark containers) - // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in - // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset - // of cores on a node. - /* - else { - // If no java_opts specified, default to using -XX:+CMSIncrementalMode - // It might be possible that other modes/config is being done in - // spark.executor.extraJavaOptions, so we dont want to mess with it. - // In our expts, using (default) throughput collector has severe perf ramnifications in - // multi-tennent machines - // The options are based on - // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use - // %20the%20Concurrent%20Low%20Pause%20Collector|outline - javaOpts += " -XX:+UseConcMarkSweepGC " - javaOpts += " -XX:+CMSIncrementalMode " - javaOpts += " -XX:+CMSIncrementalPacing " - javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 " - javaOpts += " -XX:CMSIncrementalDutyCycle=10 " - } - */ - - // For log4j configuration to reference - javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) - - val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", - "-server", - // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. - // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in - // an inconsistent state. - // TODO: If the OOM is not recoverable by rescheduling it on different node, then do - // 'something' to fail job ... akin to blacklisting trackers in mapred ? 
- "-XX:OnOutOfMemoryError='kill %p'") ++ - javaOpts ++ - Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", - masterAddress.toString, - slaveId.toString, - hostname.toString, - executorCores.toString, - appId, - "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", - "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - - // TODO: it would be nicer to just make sure there are no null commands here - commands.map(s => if (s == null) "null" else s).toList - } - - private def setupDistributedCache( - file: String, - rtype: LocalResourceType, - localResources: HashMap[String, LocalResource], - timestamp: String, - size: String, - vis: String): Unit = { - val uri = new URI(file) - val amJarRsrc = Records.newRecord(classOf[LocalResource]) - amJarRsrc.setType(rtype) - amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) - amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) - amJarRsrc.setTimestamp(timestamp.toLong) - amJarRsrc.setSize(size.toLong) - localResources(uri.getFragment()) = amJarRsrc - } - - def prepareLocalResources: HashMap[String, LocalResource] = { - logInfo("Preparing Local resources") - val localResources = HashMap[String, LocalResource]() - - if (System.getenv("SPARK_YARN_CACHE_FILES") != null) { - val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') - val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',') - val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',') - for( i <- 0 to distFiles.length - 1) { - setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i), - fileSizes(i), visibilities(i)) - } - } - - if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) { - val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',') - val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',') - val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',') - val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',') - for( i <- 0 to distArchives.length - 1) { - setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources, - timeStamps(i), fileSizes(i), visibilities(i)) - } - } - - logInfo("Prepared Local resources " + localResources) - localResources - } - - def prepareEnvironment: HashMap[String, String] = { - val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) - - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - } - - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } - env - } - -} From a6394bc2c094c6c662237236c2effa2dabe67910 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 6 Jan 2015 00:31:19 -0800 Subject: [PATCH 065/215] [SPARK-1600] Refactor FileInputStream tests to remove Thread.sleep() calls and SystemClock usage This patch refactors Spark Streaming's FileInputStream tests to 
remove uses of Thread.sleep() and SystemClock, which should hopefully resolve some longstanding flakiness in these tests (see SPARK-1600). Key changes: - Modify FileInputDStream to use the scheduler's Clock instead of System.currentTimeMillis(); this allows it to be tested using ManualClock. - Fix a synchronization issue in ManualClock's `currentTime` method. - Add a StreamingTestWaiter class which allows callers to block until a certain number of batches have finished. - Change the FileInputStream tests so that files' modification times are manually set based off of ManualClock; this eliminates many Thread.sleep calls. - Update these tests to use the withStreamingContext fixture. Author: Josh Rosen Closes #3801 from JoshRosen/SPARK-1600 and squashes the following commits: e4494f4 [Josh Rosen] Address a potential race when setting file modification times 8340bd0 [Josh Rosen] Use set comparisons for output. 0b9c252 [Josh Rosen] Fix some ManualClock usage problems. 1cc689f [Josh Rosen] ConcurrentHashMap -> SynchronizedMap db26c3a [Josh Rosen] Use standard timeout in ScalaTest `eventually` blocks. 3939432 [Josh Rosen] Rename StreamingTestWaiter to BatchCounter 0b9c3a1 [Josh Rosen] Wait for checkpoint to complete 863d71a [Josh Rosen] Remove Thread.sleep that was used to make task run slowly b4442c3 [Josh Rosen] batchTimeToSelectedFiles should be thread-safe 15b48ee [Josh Rosen] Replace several TestWaiter methods w/ ScalaTest eventually. fffc51c [Josh Rosen] Revert "Remove last remaining sleep() call" dbb8247 [Josh Rosen] Remove last remaining sleep() call 566a63f [Josh Rosen] Fix log message and comment typos da32f3f [Josh Rosen] Fix log message and comment typos 3689214 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-1600 c8f06b1 [Josh Rosen] Remove Thread.sleep calls in FileInputStream CheckpointSuite test. d4f2d87 [Josh Rosen] Refactor file input stream tests to not rely on SystemClock. dda1403 [Josh Rosen] Add StreamingTestWaiter class. 3c3efc3 [Josh Rosen] Synchronize `currentTime` in ManualClock a95ddc4 [Josh Rosen] Modify FileInputDStream to use Clock class. 
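To make the clock-driven approach of this patch concrete: a manual clock with synchronized access lets a test both stamp files and trigger batches deterministically. The sketch below is a simplified stand-in (not Spark's ManualClock), with a writeFile helper in the spirit of the one added to CheckpointSuite.

    import java.io.File
    import java.nio.charset.StandardCharsets
    import java.nio.file.Files

    // Illustrative stand-in for the manual clock used by the refactored tests.
    class SimpleManualClock {
      private var time = 0L
      def currentTime(): Long = this.synchronized { time }
      def addToTime(delta: Long): Unit = this.synchronized { time += delta }
    }

    object ClockedFileSketch {
      // Write `i` into a file and pin its modification time to the manual clock, so a
      // clock-driven file stream can decide deterministically which batch should see it.
      def writeFile(dir: File, i: Int, clock: SimpleManualClock): Unit = {
        val file = new File(dir, i.toString)
        Files.write(file.toPath, s"$i\n".getBytes(StandardCharsets.UTF_8))
        assert(file.setLastModified(clock.currentTime()))
      }

      def main(args: Array[String]): Unit = {
        val dir = Files.createTempDirectory("clock-sketch").toFile
        val clock = new SimpleManualClock
        // Advance in whole seconds: many filesystems truncate modification times.
        clock.addToTime(2000)
        writeFile(dir, 1, clock)
        println(new File(dir, "1").lastModified())  // 2000 on most filesystems
      }
    }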
--- .../streaming/dstream/FileInputDStream.scala | 16 +- .../apache/spark/streaming/util/Clock.scala | 6 +- .../streaming/BasicOperationsSuite.scala | 2 +- .../spark/streaming/CheckpointSuite.scala | 248 +++++++++++------- .../spark/streaming/InputStreamsSuite.scala | 69 +++-- .../spark/streaming/TestSuiteBase.scala | 46 +++- 6 files changed, 251 insertions(+), 136 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 5f13fdc5579e..e7c5639a6349 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.dstream import java.io.{IOException, ObjectInputStream} +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.reflect.ClassTag @@ -74,12 +75,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { + // This is a def so that it works during checkpoint recovery: + private def clock = ssc.scheduler.clock + // Data to be saved as part of the streaming checkpoints protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData // Initial ignore threshold based on which old, existing files in the directory (at the time of // starting the streaming application) will be ignored or considered - private val initialModTimeIgnoreThreshold = if (newFilesOnly) System.currentTimeMillis() else 0L + private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.currentTime() else 0L /* * Make sure that the information of files selected in the last few batches are remembered. 
@@ -91,8 +95,9 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas remember(durationToRemember) // Map of batch-time to selected file info for the remembered batches + // This is a concurrent map because it's also accessed in unit tests @transient private[streaming] var batchTimeToSelectedFiles = - new mutable.HashMap[Time, Array[String]] + new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] // Set of files that were selected in the remembered batches @transient private var recentlySelectedFiles = new mutable.HashSet[String]() @@ -151,7 +156,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): Array[String] = { try { - lastNewFileFindingTime = System.currentTimeMillis + lastNewFileFindingTime = clock.currentTime() // Calculate ignore threshold val modTimeIgnoreThreshold = math.max( @@ -164,7 +169,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) } val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) - val timeTaken = System.currentTimeMillis - lastNewFileFindingTime + val timeTaken = clock.currentTime() - lastNewFileFindingTime logInfo("Finding new files took " + timeTaken + " ms") logDebug("# cached file times = " + fileToModTime.size) if (timeTaken > slideDuration.milliseconds) { @@ -267,7 +272,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () - batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]]() + batchTimeToSelectedFiles = + new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() fileToModTime = new TimeStampedHashMap[String, Long](true) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index 7cd867ce34b8..d6d96d7ba00f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -59,9 +59,11 @@ class SystemClock() extends Clock { private[streaming] class ManualClock() extends Clock { - var time = 0L + private var time = 0L - def currentTime() = time + def currentTime() = this.synchronized { + time + } def setTime(timeToSet: Long) = { this.synchronized { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 199f5e716112..e8f4a7779ec2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -638,7 +638,7 @@ class BasicOperationsSuite extends TestSuiteBase { if (rememberDuration != null) ssc.remember(rememberDuration) val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.time === Seconds(10).milliseconds) + assert(clock.currentTime() === Seconds(10).milliseconds) assert(output.size === numExpectedOutput) operatedStream } diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 5d232c6ade7a..8f8bc61437ba 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -18,17 +18,18 @@ package org.apache.spark.streaming import java.io.File -import java.nio.charset.Charset -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.reflect.ClassTag +import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.scalatest.concurrent.Eventually._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.util.ManualClock @@ -45,8 +46,6 @@ class CheckpointSuite extends TestSuiteBase { override def batchDuration = Milliseconds(500) - override def actuallyWait = true // to allow checkpoints to be written - override def beforeFunction() { super.beforeFunction() Utils.deleteRecursively(new File(checkpointDir)) @@ -143,7 +142,6 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() advanceTimeWithRealDelay(ssc, 4) ssc.stop() - System.clearProperty("spark.streaming.manualClock.jump") ssc = null } @@ -312,109 +310,161 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } - // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window operation. // It also tests whether batches, whose processing was incomplete due to the // failure, are re-processed or not. test("recovery with file input stream") { // Set up the streaming context and input streams + val batchDuration = Seconds(2) // Due to 1-second resolution of setLastModified() on some OS's. 
val testDir = Utils.createTempDir() - var ssc = new StreamingContext(master, framework, Seconds(1)) - ssc.checkpoint(checkpointDir) - val fileStream = ssc.textFileStream(testDir.toString) - // Making value 3 take large time to process, to ensure that the master - // shuts down in the middle of processing the 3rd batch - val mappedStream = fileStream.map(s => { - val i = s.toInt - if (i == 3) Thread.sleep(2000) - i - }) - - // Reducing over a large window to ensure that recovery from master failure - // requires reprocessing of all the files seen before the failure - val reducedStream = mappedStream.reduceByWindow(_ + _, Seconds(30), Seconds(1)) - val outputBuffer = new ArrayBuffer[Seq[Int]] - var outputStream = new TestOutputStream(reducedStream, outputBuffer) - outputStream.register() - ssc.start() - - // Create files and advance manual clock to process them - // var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - Thread.sleep(1000) - for (i <- Seq(1, 2, 3)) { - Files.write(i + "\n", new File(testDir, i.toString), Charset.forName("UTF-8")) - // wait to make sure that the file is written such that it gets shown in the file listings - Thread.sleep(1000) + val outputBuffer = new ArrayBuffer[Seq[Int]] with SynchronizedBuffer[Seq[Int]] + + /** + * Writes a file named `i` (which contains the number `i`) to the test directory and sets its + * modification time to `clock`'s current time. + */ + def writeFile(i: Int, clock: ManualClock): Unit = { + val file = new File(testDir, i.toString) + Files.write(i + "\n", file, Charsets.UTF_8) + assert(file.setLastModified(clock.currentTime())) + // Check that the file's modification date is actually the value we wrote, since rounding or + // truncation will break the test: + assert(file.lastModified() === clock.currentTime()) } - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0, "No files processed before restart") - ssc.stop() - // Verify whether files created have been recorded correctly or not - var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - def recordedFiles = fileInputDStream.batchTimeToSelectedFiles.values.flatten - assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) - - // Create files while the master is down - for (i <- Seq(4, 5, 6)) { - Files.write(i + "\n", new File(testDir, i.toString), Charset.forName("UTF-8")) - Thread.sleep(1000) + /** + * Returns ids that identify which files which have been recorded by the file input stream. 
+ */ + def recordedFiles(ssc: StreamingContext): Seq[Int] = { + val fileInputDStream = + ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + val filenames = fileInputDStream.batchTimeToSelectedFiles.values.flatten + filenames.map(_.split(File.separator).last.toInt).toSeq.sorted } - // Recover context from checkpoint file and verify whether the files that were - // recorded before failure were saved and successfully recovered - logInfo("*********** RESTARTING ************") - ssc = new StreamingContext(checkpointDir) - fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] - assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + try { + // This is a var because it's re-assigned when we restart from a checkpoint + var clock: ManualClock = null + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.checkpoint(checkpointDir) + clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val batchCounter = new BatchCounter(ssc) + val fileStream = ssc.textFileStream(testDir.toString) + // Make value 3 take a large time to process, to ensure that the driver + // shuts down in the middle of processing the 3rd batch + CheckpointSuite.batchThreeShouldBlockIndefinitely = true + val mappedStream = fileStream.map(s => { + val i = s.toInt + if (i == 3) { + while (CheckpointSuite.batchThreeShouldBlockIndefinitely) { + Thread.sleep(Long.MaxValue) + } + } + i + }) + + // Reducing over a large window to ensure that recovery from driver failure + // requires reprocessing of all the files seen before the failure + val reducedStream = mappedStream.reduceByWindow(_ + _, batchDuration * 30, batchDuration) + val outputStream = new TestOutputStream(reducedStream, outputBuffer) + outputStream.register() + ssc.start() + + // Advance half a batch so that the first file is created after the StreamingContext starts + clock.addToTime(batchDuration.milliseconds / 2) + // Create files and advance manual clock to process them + for (i <- Seq(1, 2, 3)) { + writeFile(i, clock) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.addToTime(batchDuration.milliseconds) + if (i != 3) { + // Since we want to shut down while the 3rd batch is processing + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === i) + } + } + } + clock.addToTime(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + // Wait until all files have been recorded and all batches have started + assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) + } + // Wait for a checkpoint to be written + val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) + eventually(eventuallyTimeout) { + assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + } + ssc.stop() + // Check that we shut down while the third batch was being processed + assert(batchCounter.getNumCompletedBatches === 2) + assert(outputStream.output.flatten === Seq(1, 3)) + } - // Restart stream computation - ssc.start() - for (i <- Seq(7, 8, 9)) { - Files.write(i + "\n", new File(testDir, i.toString), Charset.forName("UTF-8")) - Thread.sleep(1000) - } - Thread.sleep(1000) - logInfo("Output = " + outputStream.output.mkString("[", ", ", "]")) - assert(outputStream.output.size > 0, "No files processed after restart") - ssc.stop() + // The original 
StreamingContext has now been stopped. + CheckpointSuite.batchThreeShouldBlockIndefinitely = false - // Verify whether files created while the driver was down have been recorded or not - assert(!recordedFiles.filter(_.endsWith("4")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("5")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("6")).isEmpty) - - // Verify whether new files created after recover have been recorded or not - assert(!recordedFiles.filter(_.endsWith("7")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("8")).isEmpty) - assert(!recordedFiles.filter(_.endsWith("9")).isEmpty) - - // Append the new output to the old buffer - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] - outputBuffer ++= outputStream.output - - val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) - logInfo("--------------------------------") - logInfo("output, size = " + outputBuffer.size) - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output, size = " + expectedOutput.size) - expectedOutput.foreach(x => logInfo("[" + x + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - val output = outputBuffer.flatMap(x => x) - assert(output.contains(6)) // To ensure that the 3rd input (i.e., 3) was processed - output.foreach(o => // To ensure all the inputs are correctly added cumulatively - assert(expectedOutput.contains(o), "Expected value " + o + " not found") - ) - // To ensure that all the inputs were received correctly - assert(expectedOutput.last === output.last) - Utils.deleteRecursively(testDir) + // Create files while the streaming driver is down + for (i <- Seq(4, 5, 6)) { + writeFile(i, clock) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.addToTime(batchDuration.milliseconds) + } + + // Recover context from checkpoint file and verify whether the files that were + // recorded before failure were saved and successfully recovered + logInfo("*********** RESTARTING ************") + withStreamingContext(new StreamingContext(checkpointDir)) { ssc => + // So that the restarted StreamingContext's clock has gone forward in time since failure + ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString) + val oldClockTime = clock.currentTime() + clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val batchCounter = new BatchCounter(ssc) + val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] + // Check that we remember files that were recorded before the restart + assert(recordedFiles(ssc) === Seq(1, 2, 3)) + + // Restart stream computation + ssc.start() + // Verify that the clock has traveled forward to the expected time + eventually(eventuallyTimeout) { + clock.currentTime() === oldClockTime + } + // Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch) + val numBatchesAfterRestart = 4 + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === numBatchesAfterRestart) + } + for ((i, index) <- Seq(7, 8, 9).zipWithIndex) { + writeFile(i, clock) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.addToTime(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1) + } + } + clock.addToTime(batchDuration.milliseconds) + logInfo("Output 
after restart = " + outputStream.output.mkString("[", ", ", "]")) + assert(outputStream.output.size > 0, "No files processed after restart") + ssc.stop() + + // Verify whether files created while the driver was down (4, 5, 6) and files created after + // recovery (7, 8, 9) have been recorded + assert(recordedFiles(ssc) === (1 to 9)) + + // Append the new output to the old buffer + outputBuffer ++= outputStream.output + + // Verify whether all the elements received are as expected + val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) + assert(outputBuffer.flatten.toSet === expectedOutput.toSet) + } + } finally { + Utils.deleteRecursively(testDir) + } } @@ -471,12 +521,12 @@ class CheckpointSuite extends TestSuiteBase { */ def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.time) + logInfo("Manual clock before advancing = " + clock.currentTime()) for (i <- 1 to numBatches.toInt) { clock.addToTime(batchDuration.milliseconds) Thread.sleep(batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.time) + logInfo("Manual clock after advancing = " + clock.currentTime()) Thread.sleep(batchDuration.milliseconds) val outputStream = ssc.graph.getOutputStreams.filter { dstream => @@ -485,3 +535,7 @@ class CheckpointSuite extends TestSuiteBase { outputStream.output.map(_.flatten) } } + +private object CheckpointSuite extends Serializable { + var batchThreeShouldBlockIndefinitely: Boolean = true +} \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 307052a4a9cb..bddf51e13042 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -28,7 +28,6 @@ import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} -import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.Files @@ -234,45 +233,57 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } def testFileStream(newFilesOnly: Boolean) { - var ssc: StreamingContext = null val testDir: File = null try { + val batchDuration = Seconds(2) val testDir = Utils.createTempDir() + // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") Files.write("0\n", existingFile, Charset.forName("UTF-8")) + assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) - Thread.sleep(1000) // Set up the streaming context and input streams - val newConf = conf.clone.set( - "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") - ssc = new StreamingContext(newConf, batchDuration) - val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( - testDir.toString, (x: Path) => true, newFilesOnly = newFilesOnly).map(_._2.toString) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(fileStream, outputBuffer) - outputStream.register() - ssc.start() - - // Create files in the directory - val input = Seq(1, 2, 3, 4, 5) - input.foreach { i => - 
Thread.sleep(batchDuration.milliseconds) - val file = new File(testDir, i.toString) - Files.write(i + "\n", file, Charset.forName("UTF-8")) - logInfo("Created file " + file) - } + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + // This `setTime` call ensures that the clock is past the creation time of `existingFile` + clock.setTime(existingFile.lastModified + batchDuration.milliseconds) + val batchCounter = new BatchCounter(ssc) + val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( + testDir.toString, (x: Path) => true, newFilesOnly = newFilesOnly).map(_._2.toString) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputBuffer) + outputStream.register() + ssc.start() + + // Advance the clock so that the files are created after StreamingContext starts, but + // not enough to trigger a batch + clock.addToTime(batchDuration.milliseconds / 2) + + // Over time, create files in the directory + val input = Seq(1, 2, 3, 4, 5) + input.foreach { i => + val file = new File(testDir, i.toString) + Files.write(i + "\n", file, Charset.forName("UTF-8")) + assert(file.setLastModified(clock.currentTime())) + assert(file.lastModified === clock.currentTime) + logInfo("Created file " + file) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.addToTime(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === i) + } + } - // Verify that all the files have been read - val expectedOutput = if (newFilesOnly) { - input.map(_.toString).toSet - } else { - (Seq(0) ++ input).map(_.toString).toSet - } - eventually(timeout(maxWaitTimeMillis milliseconds), interval(100 milliseconds)) { + // Verify that all the files have been read + val expectedOutput = if (newFilesOnly) { + input.map(_.toString).toSet + } else { + (Seq(0) ++ input).map(_.toString).toSet + } assert(outputBuffer.flatten.toSet === expectedOutput) } } finally { - if (ssc != null) ssc.stop() if (testDir != null) Utils.deleteRecursively(testDir) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 52972f63c6c5..7d82c3e4aadc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -21,11 +21,16 @@ import java.io.{ObjectInputStream, IOException} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.SynchronizedBuffer +import scala.language.implicitConversions import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} +import org.scalatest.concurrent.Eventually.timeout +import org.scalatest.concurrent.PatienceConfiguration import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD @@ -103,6 +108,40 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T], def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) } 
+/** + * An object that counts the number of started / completed batches. This is implemented using a + * StreamingListener. Constructing a new instance automatically registers a StreamingListener on + * the given StreamingContext. + */ +class BatchCounter(ssc: StreamingContext) { + + // All access to this state should be guarded by `BatchCounter.this.synchronized` + private var numCompletedBatches = 0 + private var numStartedBatches = 0 + + private val listener = new StreamingListener { + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = + BatchCounter.this.synchronized { + numStartedBatches += 1 + BatchCounter.this.notifyAll() + } + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = + BatchCounter.this.synchronized { + numCompletedBatches += 1 + BatchCounter.this.notifyAll() + } + } + ssc.addStreamingListener(listener) + + def getNumCompletedBatches: Int = this.synchronized { + numCompletedBatches + } + + def getNumStartedBatches: Int = this.synchronized { + numStartedBatches + } +} + /** * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. @@ -142,6 +181,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { .setMaster(master) .setAppName(framework) + // Timeout for use in ScalaTest `eventually` blocks + val eventuallyTimeout: PatienceConfiguration.Timeout = timeout(Span(10, ScalaTestSeconds)) + // Default before function for any streaming test suite. Override this // if you want to add your stuff to "before" (i.e., don't call before { } ) def beforeFunction() { @@ -291,7 +333,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - logInfo("Manual clock before advancing = " + clock.time) + logInfo("Manual clock before advancing = " + clock.currentTime()) if (actuallyWait) { for (i <- 1 to numBatches) { logInfo("Actually waiting for " + batchDuration) @@ -301,7 +343,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { } else { clock.addToTime(numBatches * batchDuration.milliseconds) } - logInfo("Manual clock after advancing = " + clock.time) + logInfo("Manual clock after advancing = " + clock.currentTime()) // Wait until expected number of output items have been generated val startTime = System.currentTimeMillis() From 5e3ec1110495899a298313c4aa9c6c151c1f54da Mon Sep 17 00:00:00 2001 From: kj-ki Date: Tue, 6 Jan 2015 09:49:37 -0800 Subject: [PATCH 066/215] [Minor] Fix comments for GraphX 2D partitioning strategy The sum of vertices on matrix (v0 to v11) is 12. And, I think one same block overlaps in this strategy. This is minor PR, so I didn't file in JIRA. Author: kj-ki Closes #3904 from kj-ki/fix-partitionstrategy-comments and squashes the following commits: 79829d9 [kj-ki] Fix comments for 2D partitioning. 
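Returning briefly to the BatchCounter added to TestSuiteBase in the SPARK-1600 patch above: because the listener increments its counters under `synchronized` and calls `notifyAll()`, a test can block on a batch count instead of sleeping. A rough sketch of that waiting pattern is below; the waitUntilBatchesCompleted helper is illustrative (the patch's tests pair the counter with ScalaTest's `eventually` instead).

    // Illustrative counter mirroring BatchCounter's synchronized counters plus notifyAll().
    class CounterSketch {
      private var completed = 0

      def onBatchCompleted(): Unit = this.synchronized {
        completed += 1
        this.notifyAll()
      }

      def getNumCompletedBatches: Int = this.synchronized { completed }

      // Block until `expected` batches have completed or the timeout elapses.
      def waitUntilBatchesCompleted(expected: Int, timeoutMillis: Long): Boolean =
        this.synchronized {
          val deadline = System.currentTimeMillis() + timeoutMillis
          while (completed < expected && System.currentTimeMillis() < deadline) {
            this.wait(math.max(1L, deadline - System.currentTimeMillis()))
          }
          completed >= expected
        }
    }

    object CounterSketch {
      def main(args: Array[String]): Unit = {
        val counter = new CounterSketch
        new Thread(new Runnable {
          def run(): Unit = counter.onBatchCompleted()  // simulates a batch finishing
        }).start()
        println(counter.waitUntilBatchesCompleted(1, 5000))  // prints true
      }
    }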
--- .../scala/org/apache/spark/graphx/PartitionStrategy.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 13033fee0e6b..7372dfbd9fe9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -32,9 +32,9 @@ trait PartitionStrategy extends Serializable { object PartitionStrategy { /** * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix, - * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication. + * guaranteeing a `2 * sqrt(numParts) - 1` bound on vertex replication. * - * Suppose we have a graph with 11 vertices that we want to partition + * Suppose we have a graph with 12 vertices that we want to partition * over 9 machines. We can use the following sparse matrix representation: * *
    @@ -61,7 +61,7 @@ object PartitionStrategy {
        * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3,
        * P6)` or the last
        * row of blocks `(P6, P7, P8)`.  As a consequence we can guarantee that `v11` will need to be
    -   * replicated to at most `2 * sqrt(numParts)` machines.
    +   * replicated to at most `2 * sqrt(numParts) - 1` machines.
        *
        * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work
        * balance.  To improve balance we first multiply each vertex id by a large prime to shuffle the
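The arithmetic behind the corrected `2 * sqrt(numParts) - 1` bound can be sketched as follows: the source vertex fixes the column (at most `sqrt(numParts)` blocks), the destination vertex fixes the row (another `sqrt(numParts)`), and the two sets share exactly one block, hence the `- 1`. The code below is a simplified, self-contained version of the EdgePartition2D idea, not GraphX's implementation, and the mixing prime is only illustrative of the "multiply each vertex id by a large prime" step mentioned in the scaladoc.

    // Simplified 2D edge-partitioning sketch; not GraphX's exact code.
    object EdgePartition2DSketch {
      val mixingPrime = 1125899906842597L  // illustrative large prime for shuffling vertex ids

      def partitionFor(src: Long, dst: Long, numParts: Int): Int = {
        val ceilSqrt = math.ceil(math.sqrt(numParts)).toInt
        val col = (math.abs(src * mixingPrime) % ceilSqrt).toInt
        val row = (math.abs(dst * mixingPrime) % ceilSqrt).toInt
        (col * ceilSqrt + row) % numParts
      }

      def main(args: Array[String]): Unit = {
        val numParts = 9
        val v = 11L
        // Partitions v can land in, over many possible neighbors on either side.
        val others = 0L until 100L
        val asSrc = others.map(partitionFor(v, _, numParts)).toSet
        val asDst = others.map(partitionFor(_, v, numParts)).toSet
        val replicas = asSrc ++ asDst
        println(s"vertex $v is replicated to ${replicas.size} of $numParts partitions")
        assert(replicas.size <= 2 * math.ceil(math.sqrt(numParts)).toInt - 1)
      }
    }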
    
    From 4cba6eb42031b1a4cc3308833116ca5d9ccb1a89 Mon Sep 17 00:00:00 2001
    From: Sean Owen 
    Date: Tue, 6 Jan 2015 12:02:08 -0800
    Subject: [PATCH 067/215] SPARK-4159 [CORE] Maven build doesn't run JUnit test
     suites
    
    This PR:
    
    - Reenables `surefire`, and copies config from `scalatest` (which is itself an old fork of `surefire`, so similar)
    - Tells `surefire` to test only Java tests
    - Enables `surefire` and `scalatest` for all children, and in turn eliminates some duplication.
    
    For me this causes the Scala and Java tests to be run once each, it seems, as desired. It doesn't affect the SBT build but works for Maven. I still need to verify that all of the Scala tests and Java tests are being run.
    
    Author: Sean Owen 
    
    Closes #3651 from srowen/SPARK-4159 and squashes the following commits:
    
    2e8a0af [Sean Owen] Remove specialized SPARK_HOME setting for REPL, YARN tests as it appears to be obsolete
    12e4558 [Sean Owen] Append to unit-test.log instead of overwriting, so that both surefire and scalatest output is preserved. Also standardize/correct comments a bit.
    e6f8601 [Sean Owen] Reenable Java tests by reenabling surefire with config cloned from scalatest; centralize test config in the parent
    ---
     bagel/pom.xml                                 | 11 -----
     bagel/src/test/resources/log4j.properties     |  4 +-
     core/pom.xml                                  | 18 --------
     core/src/test/resources/log4j.properties      |  4 +-
     examples/pom.xml                              |  5 ---
     external/flume-sink/pom.xml                   |  9 ----
     .../src/test/resources/log4j.properties       |  3 +-
     external/flume/pom.xml                        | 11 -----
     .../flume/src/test/resources/log4j.properties |  5 +--
     external/kafka/pom.xml                        | 11 -----
     .../kafka/src/test/resources/log4j.properties |  5 +--
     external/mqtt/pom.xml                         | 11 -----
     .../mqtt/src/test/resources/log4j.properties  |  5 +--
     external/twitter/pom.xml                      | 11 -----
     .../src/test/resources/log4j.properties       |  5 +--
     external/zeromq/pom.xml                       | 11 -----
     .../src/test/resources/log4j.properties       |  5 +--
     extras/java8-tests/pom.xml                    | 15 -------
     .../src/test/resources/log4j.properties       |  2 +-
     extras/kinesis-asl/pom.xml                    | 11 -----
     .../src/test/resources/log4j.properties       |  5 ++-
     graphx/pom.xml                                | 11 -----
     graphx/src/test/resources/log4j.properties    |  4 +-
     mllib/pom.xml                                 | 11 -----
     mllib/src/test/resources/log4j.properties     |  4 +-
     network/common/pom.xml                        |  5 ---
     network/shuffle/pom.xml                       |  5 ---
     pom.xml                                       | 45 +++++++++++++++++--
     repl/pom.xml                                  | 14 ------
     repl/src/test/resources/log4j.properties      |  4 +-
     sql/catalyst/pom.xml                          | 10 -----
     sql/core/pom.xml                              | 11 -----
     sql/hive-thriftserver/pom.xml                 |  9 ----
     sql/hive/pom.xml                              |  5 ---
     streaming/pom.xml                             | 10 -----
     streaming/src/test/resources/log4j.properties |  5 +--
     tools/pom.xml                                 |  9 ----
     yarn/pom.xml                                  | 14 ------
     yarn/src/test/resources/log4j.properties      |  4 +-
     39 files changed, 70 insertions(+), 277 deletions(-)
    
    diff --git a/bagel/pom.xml b/bagel/pom.xml
    index 0327ffa40267..3bcd38fa3245 100644
    --- a/bagel/pom.xml
    +++ b/bagel/pom.xml
    @@ -44,11 +44,6 @@
            <groupId>org.eclipse.jetty</groupId>
            <artifactId>jetty-server</artifactId>
          </dependency>
     -    <dependency>
     -      <groupId>org.scalatest</groupId>
     -      <artifactId>scalatest_${scala.binary.version}</artifactId>
     -      <scope>test</scope>
     -    </dependency>
          <dependency>
            <groupId>org.scalacheck</groupId>
            <artifactId>scalacheck_${scala.binary.version}</artifactId>
     @@ -58,11 +53,5 @@
        <build>
          <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
          <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     -    <plugins>
     -      <plugin>
     -        <groupId>org.scalatest</groupId>
     -        <artifactId>scalatest-maven-plugin</artifactId>
     -      </plugin>
     -    </plugins>
        </build>
      </project>
    diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
    index 789869f72e3b..853ef0ed2986 100644
    --- a/bagel/src/test/resources/log4j.properties
    +++ b/bagel/src/test/resources/log4j.properties
    @@ -15,10 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file bagel/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/core/pom.xml b/core/pom.xml
    index c5c41b2b5de4..d9a49c9e08af 100644
    --- a/core/pom.xml
    +++ b/core/pom.xml
    @@ -276,11 +276,6 @@
            <artifactId>selenium-java</artifactId>
            <scope>test</scope>
          </dependency>
     -    <dependency>
     -      <groupId>org.scalatest</groupId>
     -      <artifactId>scalatest_${scala.binary.version}</artifactId>
     -      <scope>test</scope>
     -    </dependency>
          <dependency>
            <groupId>org.mockito</groupId>
            <artifactId>mockito-all</artifactId>
     @@ -326,19 +321,6 @@
          <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
          <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
          <plugins>
     -      <plugin>
     -        <groupId>org.scalatest</groupId>
     -        <artifactId>scalatest-maven-plugin</artifactId>
     -        <executions>
     -          <execution>
     -            <id>test</id>
     -            <goals>
     -              <goal>test</goal>
     -            </goals>
     -          </execution>
     -        </executions>
     -      </plugin>
     -
           
           
             org.apache.maven.plugins
    diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
    index 9dd05f17f012..287c8e356350 100644
    --- a/core/src/test/resources/log4j.properties
    +++ b/core/src/test/resources/log4j.properties
    @@ -15,10 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file core/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/examples/pom.xml b/examples/pom.xml
    index 8713230e1e8e..bdc5d0562f3e 100644
    --- a/examples/pom.xml
    +++ b/examples/pom.xml
    @@ -244,11 +244,6 @@
           algebird-core_${scala.binary.version}
           0.8.1
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
    index 72618b6515f8..71f595d0a680 100644
    --- a/external/flume-sink/pom.xml
    +++ b/external/flume-sink/pom.xml
    @@ -65,11 +65,6 @@
             
           
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scala-lang
           scala-library
    @@ -91,10 +86,6 @@
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
         
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
           
             org.apache.avro
             avro-maven-plugin
    diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties
    index 4411d6e20c52..2a58e9981722 100644
    --- a/external/flume-sink/src/test/resources/log4j.properties
    +++ b/external/flume-sink/src/test/resources/log4j.properties
    @@ -17,9 +17,8 @@
     
     # Set everything to be logged to the file streaming/target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/external/flume/pom.xml b/external/flume/pom.xml
    index a682f0e8471d..0374262212e0 100644
    --- a/external/flume/pom.xml
    +++ b/external/flume/pom.xml
    @@ -61,11 +61,6 @@
             
           
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -85,11 +80,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties
    index 4411d6e20c52..9697237bfa1a 100644
    --- a/external/flume/src/test/resources/log4j.properties
    +++ b/external/flume/src/test/resources/log4j.properties
    @@ -15,11 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file streaming/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
    index b3f44471cd32..b29b0509656b 100644
    --- a/external/kafka/pom.xml
    +++ b/external/kafka/pom.xml
    @@ -74,11 +74,6 @@
           3.2
           test
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -98,11 +93,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties
    index 4411d6e20c52..9697237bfa1a 100644
    --- a/external/kafka/src/test/resources/log4j.properties
    +++ b/external/kafka/src/test/resources/log4j.properties
    @@ -15,11 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file streaming/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
    index d478267b605b..560c8b9d1827 100644
    --- a/external/mqtt/pom.xml
    +++ b/external/mqtt/pom.xml
    @@ -46,11 +46,6 @@
           org.eclipse.paho.client.mqttv3
           1.0.1
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -76,11 +71,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties
    index 4411d6e20c52..9697237bfa1a 100644
    --- a/external/mqtt/src/test/resources/log4j.properties
    +++ b/external/mqtt/src/test/resources/log4j.properties
    @@ -15,11 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file streaming/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
    index 000ace1446e5..da6ffe7662f6 100644
    --- a/external/twitter/pom.xml
    +++ b/external/twitter/pom.xml
    @@ -46,11 +46,6 @@
           twitter4j-stream
           3.0.3
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -70,11 +65,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties
    index 4411d6e20c52..64bfc5745088 100644
    --- a/external/twitter/src/test/resources/log4j.properties
    +++ b/external/twitter/src/test/resources/log4j.properties
    @@ -15,11 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file streaming/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
    index 29c452093502..2fb5f0ed2f57 100644
    --- a/external/zeromq/pom.xml
    +++ b/external/zeromq/pom.xml
    @@ -46,11 +46,6 @@
           akka-zeromq_${scala.binary.version}
           ${akka.version}
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -70,11 +65,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties
    index 4411d6e20c52..9697237bfa1a 100644
    --- a/external/zeromq/src/test/resources/log4j.properties
    +++ b/external/zeromq/src/test/resources/log4j.properties
    @@ -15,11 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file streaming/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
    index c8477a656631..0fb431808bac 100644
    --- a/extras/java8-tests/pom.xml
    +++ b/extras/java8-tests/pom.xml
    @@ -60,11 +60,6 @@
           junit-interface
           test
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
       
     
       
    @@ -159,16 +154,6 @@
               
             
           
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -        
    -          
    -            test
    -            none
    -          
    -        
    -      
         
       
     
    diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties
    index bb0ab319a008..287c8e356350 100644
    --- a/extras/java8-tests/src/test/resources/log4j.properties
    +++ b/extras/java8-tests/src/test/resources/log4j.properties
    @@ -18,7 +18,7 @@
     # Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
    index c0d3a6111911..c815eda52bda 100644
    --- a/extras/kinesis-asl/pom.xml
    +++ b/extras/kinesis-asl/pom.xml
    @@ -57,11 +57,6 @@
           aws-java-sdk
           ${aws.java.sdk.version}
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.mockito
           mockito-all
    @@ -86,11 +81,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties
    index d9d08f68687d..853ef0ed2986 100644
    --- a/extras/kinesis-asl/src/test/resources/log4j.properties
    +++ b/extras/kinesis-asl/src/test/resources/log4j.properties
    @@ -14,10 +14,11 @@
     # See the License for the specific language governing permissions and
     # limitations under the License.
     #
    +
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
    -# log4j.appender.file=org.apache.log4j.FileAppender
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/graphx/pom.xml b/graphx/pom.xml
    index 9982b36f9b62..91db799d244a 100644
    --- a/graphx/pom.xml
    +++ b/graphx/pom.xml
    @@ -49,11 +49,6 @@
           org.eclipse.jetty
           jetty-server
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -63,11 +58,5 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
       
     
    diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties
    index 9dd05f17f012..287c8e356350 100644
    --- a/graphx/src/test/resources/log4j.properties
    +++ b/graphx/src/test/resources/log4j.properties
    @@ -15,10 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file core/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/mllib/pom.xml b/mllib/pom.xml
    index 0a6dda0ab8c8..219875748168 100644
    --- a/mllib/pom.xml
    +++ b/mllib/pom.xml
    @@ -80,11 +80,6 @@
           org.apache.commons
           commons-math3
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -129,12 +124,6 @@
       
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
    -    
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -    
         
           
             ../python
    diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties
    index a469badf603c..9697237bfa1a 100644
    --- a/mllib/src/test/resources/log4j.properties
    +++ b/mllib/src/test/resources/log4j.properties
    @@ -15,10 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the file core/target/unit-tests.log
    +# Set everything to be logged to the file target/unit-tests.log
     log4j.rootCategory=INFO, file
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/network/common/pom.xml b/network/common/pom.xml
    index baca859fa501..245a96b8c403 100644
    --- a/network/common/pom.xml
    +++ b/network/common/pom.xml
    @@ -75,11 +75,6 @@
           mockito-all
           test
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
       
     
       
    diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
    index 12468567c3ae..5bfa1ac9c373 100644
    --- a/network/shuffle/pom.xml
    +++ b/network/shuffle/pom.xml
    @@ -83,11 +83,6 @@
           mockito-all
           test
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
       
     
       
    diff --git a/pom.xml b/pom.xml
    index 05f59a9b4140..46ff211f9116 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -256,7 +256,7 @@
           1.0.0
         
         
    @@ -266,6 +266,15 @@
           2.3.7
           provided
         
    +    
    +    
    +      org.scalatest
    +      scalatest_${scala.binary.version}
    +      test
    +    
       
       
         
    @@ -935,19 +944,38 @@
                 true
               
             
    +        
             
               org.apache.maven.plugins
               maven-surefire-plugin
    -          2.17
    +          2.18
    +          
               
    -            
    -            true
    +            
    +              **/Test*.java
    +              **/*Test.java
    +              **/*TestCase.java
    +              **/*Suite.java
    +            
    +            ${project.build.directory}/surefire-reports
    +            -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
    +            
    +              true
    +              ${session.executionRootDirectory}
    +              1
    +              false
    +              false
    +              ${test_classpath}
    +              true
    +            
               
             
    +        
             
               org.scalatest
               scalatest-maven-plugin
               1.0
    +          
               
                 ${project.build.directory}/surefire-reports
                 .
    @@ -1159,6 +1187,15 @@
               
             
           
    +      
    +      
    +        org.apache.maven.plugins
    +        maven-surefire-plugin
    +      
    +      
    +        org.scalatest
    +        scalatest-maven-plugin
    +      
         
       
     
    diff --git a/repl/pom.xml b/repl/pom.xml
    index 9b2290429fee..97165e024926 100644
    --- a/repl/pom.xml
    +++ b/repl/pom.xml
    @@ -86,11 +86,6 @@
           org.slf4j
           jul-to-slf4j
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -115,15 +110,6 @@
               true
             
           
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -        
    -          
    -            ${basedir}/..
    -          
    -        
    -      
           
           
             org.codehaus.mojo
    diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties
    index 52098993f5c3..e7e4a4113174 100644
    --- a/repl/src/test/resources/log4j.properties
    +++ b/repl/src/test/resources/log4j.properties
    @@ -15,10 +15,10 @@
     # limitations under the License.
     #
     
    -# Set everything to be logged to the repl/target/unit-tests.log
    +# Set everything to be logged to the target/unit-tests.log
     log4j.rootCategory=INFO, file
     log4j.appender.file=org.apache.log4j.FileAppender
    -log4j.appender.file.append=false
    +log4j.appender.file.append=true
     log4j.appender.file.file=target/unit-tests.log
     log4j.appender.file.layout=org.apache.log4j.PatternLayout
     log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
    diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
    index 1caa297e24e3..a1947fb022e5 100644
    --- a/sql/catalyst/pom.xml
    +++ b/sql/catalyst/pom.xml
    @@ -50,11 +50,6 @@
           spark-core_${scala.binary.version}
           ${project.version}
         
    -    
    -      org.scalatest
    -      scalatest_${scala.binary.version}
    -      test
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    @@ -65,11 +60,6 @@
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
         
    -      
    -        org.scalatest
    -        scalatest-maven-plugin
    -      
    -
           
    +    
    +      hadoop-provided
    +      
    +        provided
    +      
    +    
    +    
    +      hive-provided
    +      
    +        provided
    +      
    +    
    +    
    +      parquet-provided
    +      
    +        provided
    +      
    +    
       
     
    diff --git a/bagel/pom.xml b/bagel/pom.xml
    index 3bcd38fa3245..510e92640eff 100644
    --- a/bagel/pom.xml
    +++ b/bagel/pom.xml
    @@ -40,10 +40,6 @@
           spark-core_${scala.binary.version}
           ${project.version}
         
    -    
    -      org.eclipse.jetty
    -      jetty-server
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
    index a4c099fb45b1..088f993954d9 100644
    --- a/bin/compute-classpath.cmd
    +++ b/bin/compute-classpath.cmd
    @@ -109,6 +109,13 @@ if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
       set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
     :no_yarn_conf_dir
     
    +rem To allow for distributions to append needed libraries to the classpath (e.g. when
    +rem using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and
    +rem append it to the final classpath.
    +if not "x%SPARK_DIST_CLASSPATH%"=="x" (
    +  set CLASSPATH=%CLASSPATH%;%SPARK_DIST_CLASSPATH%
    +)
    +
     rem A bit of a hack to allow calling this script within run2.cmd without seeing output
     if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
     
    diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
    index a31ea73d3ce1..8f3b396ffd08 100755
    --- a/bin/compute-classpath.sh
    +++ b/bin/compute-classpath.sh
    @@ -146,4 +146,11 @@ if [ -n "$YARN_CONF_DIR" ]; then
       CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
     fi
     
    +# To allow for distributions to append needed libraries to the classpath (e.g. when
    +# using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and
    +# append it to the final classpath.
    +if [ -n "$SPARK_DIST_CLASSPATH" ]; then
    +  CLASSPATH="$CLASSPATH:$SPARK_DIST_CLASSPATH"
    +fi
    +
     echo "$CLASSPATH"
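    A minimal sketch (not part of the patch) of what the two classpath scripts now do, assuming a hypothetical build_classpath helper: if SPARK_DIST_CLASSPATH is set, its value is appended to the computed classpath.

    import os

    def build_classpath(base_entries):
        # Join the entries computed so far, then append the distribution-supplied
        # classpath from SPARK_DIST_CLASSPATH when the variable is non-empty.
        classpath = os.pathsep.join(base_entries)
        dist_cp = os.environ.get("SPARK_DIST_CLASSPATH")
        if dist_cp:
            classpath = classpath + os.pathsep + dist_cp
        return classpath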
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
    index 8c7de75600b5..7eb87a564d6f 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
    +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
    @@ -55,19 +55,26 @@ private[spark] class SparkDeploySchedulerBackend(
           "{{WORKER_URL}}")
         val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
           .map(Utils.splitCommandString).getOrElse(Seq.empty)
    -    val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp =>
    -      cp.split(java.io.File.pathSeparator)
    -    }
    -    val libraryPathEntries =
    -      sc.conf.getOption("spark.executor.extraLibraryPath").toSeq.flatMap { cp =>
    -        cp.split(java.io.File.pathSeparator)
    +    val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
    +      .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
    +    val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath")
    +      .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
    +
    +    // When testing, expose the parent class path to the child. This is processed by
    +    // compute-classpath.{cmd,sh} and makes all needed jars available to child processes
    +    // when the assembly is built with the "*-provided" profiles enabled.
    +    val testingClassPath =
    +      if (sys.props.contains("spark.testing")) {
    +        sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq
    +      } else {
    +        Nil
           }
     
         // Start executors with a few necessary configs for registering with the scheduler
         val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
         val javaOpts = sparkJavaOpts ++ extraJavaOpts
         val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
    -      args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
    +      args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
         val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
         val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
           appUIAddress, sc.eventLogDir)
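    A rough Python illustration (hypothetical, mirroring the Scala change above): when the spark.testing system property is present, the parent's java.class.path is split on the path separator and added to the executor command's classpath entries.

    import os

    def testing_class_path(sys_props):
        # Only expose the parent class path to child executors while testing.
        if "spark.testing" in sys_props:
            return sys_props.get("java.class.path", "").split(os.pathsep)
        return []

    # Example: testing_class_path({"spark.testing": "1", "java.class.path": "a.jar:b.jar"})
    # returns ["a.jar", "b.jar"] on Unix-like systems.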
    diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
    index 9d6b6161ce4d..c4f1898a2db1 100644
    --- a/core/src/main/scala/org/apache/spark/util/Utils.scala
    +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
    @@ -990,11 +990,12 @@ private[spark] object Utils extends Logging {
         for ((key, value) <- extraEnvironment) {
           environment.put(key, value)
         }
    +
         val process = builder.start()
         new Thread("read stderr for " + command(0)) {
           override def run() {
             for (line <- Source.fromInputStream(process.getErrorStream).getLines()) {
    -          System.err.println(line)
    +          logInfo(line)
             }
           }
         }.start()
    @@ -1089,7 +1090,7 @@ private[spark] object Utils extends Logging {
         var firstUserLine = 0
         var insideSpark = true
         var callStack = new ArrayBuffer[String]() :+ ""
    - 
    +
         Thread.currentThread.getStackTrace().foreach { ste: StackTraceElement =>
           // When running under some profilers, the current stack trace might contain some bogus
           // frames. This is intended to ensure that we don't crash in these situations by
    diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
    index 541d8eac8055..8a54360e8179 100644
    --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
    @@ -35,7 +35,7 @@ class DriverSuite extends FunSuite with Timeouts {
         forAll(masters) { (master: String) =>
           failAfter(60 seconds) {
             Utils.executeAndGetOutput(
    -          Seq("./bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master),
    +          Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master),
               new File(sparkHome),
               Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
           }
    diff --git a/examples/pom.xml b/examples/pom.xml
    index bdc5d0562f3e..002d4458c4b3 100644
    --- a/examples/pom.xml
    +++ b/examples/pom.xml
    @@ -98,143 +98,145 @@
           ${project.version}
         
         
    -      org.eclipse.jetty
    -      jetty-server
    +      org.apache.hbase
    +      hbase-testing-util
    +      ${hbase.version}
    +      ${hbase.deps.scope}
    +      
    +        
    +          
    +          org.apache.hbase
    +          hbase-annotations
    +        
    +        
    +          org.jruby
    +          jruby-complete
    +        
    +      
    +    
    +    
    +      org.apache.hbase
    +      hbase-protocol
    +      ${hbase.version}
    +      ${hbase.deps.scope}
    +    
    +    
    +      org.apache.hbase
    +      hbase-common
    +      ${hbase.version}
    +      ${hbase.deps.scope}
    +      
    +        
    +          
    +          org.apache.hbase
    +          hbase-annotations
    +        
    +      
    +    
    +    
    +      org.apache.hbase
    +      hbase-client
    +      ${hbase.version}
    +      ${hbase.deps.scope}
    +      
    +        
    +          
    +          org.apache.hbase
    +          hbase-annotations
    +        
    +       
    +        io.netty
    +        netty
    +       
    +     
    +    
    +    
    +      org.apache.hbase
    +      hbase-server
    +      ${hbase.version}
    +      ${hbase.deps.scope}
    +      
    +        
    +          
    +          org.apache.hbase
    +          hbase-annotations
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-core
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-client
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-mapreduce-client-jobclient
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-mapreduce-client-core
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-auth
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-annotations
    +        
    +        
    +          org.apache.hadoop
    +          hadoop-hdfs
    +        
    +        
    +          org.apache.hbase
    +          hbase-hadoop1-compat
    +        
    +        
    +          org.apache.commons
    +          commons-math
    +        
    +        
    +          com.sun.jersey
    +          jersey-core
    +        
    +        
    +          org.slf4j
    +          slf4j-api
    +        
    +        
    +          com.sun.jersey
    +          jersey-server
    +        
    +        
    +          com.sun.jersey
    +          jersey-core
    +        
    +        
    +          com.sun.jersey
    +          jersey-json
    +        
    +        
    +          
    +          commons-io
    +          commons-io
    +        
    +      
    +    
    +    
    +      org.apache.hbase
    +      hbase-hadoop-compat
    +      ${hbase.version}
    +      ${hbase.deps.scope}
    +    
    +    
    +      org.apache.hbase
    +      hbase-hadoop-compat
    +      ${hbase.version}
    +      test-jar
    +      test
         
    -      
    -        org.apache.hbase
    -        hbase-testing-util
    -        ${hbase.version}
    -        
    -          
    -            
    -            org.apache.hbase
    -            hbase-annotations
    -          
    -          
    -            org.jruby
    -            jruby-complete
    -          
    -        
    -      
    -      
    -        org.apache.hbase
    -        hbase-protocol
    -        ${hbase.version}
    -      
    -      
    -        org.apache.hbase
    -        hbase-common
    -        ${hbase.version}
    -        
    -          
    -            
    -            org.apache.hbase
    -            hbase-annotations
    -          
    -        
    -      
    -      
    -        org.apache.hbase
    -        hbase-client
    -        ${hbase.version}
    -        
    -          
    -            
    -            org.apache.hbase
    -            hbase-annotations
    -          
    -         
    -          io.netty
    -          netty
    -         
    -       
    -      
    -      
    -        org.apache.hbase
    -        hbase-server
    -        ${hbase.version}
    -        
    -          
    -            org.apache.hadoop
    -            hadoop-core
    -          
    -          
    -            org.apache.hadoop
    -            hadoop-client
    -          
    -          
    -            org.apache.hadoop
    -            hadoop-mapreduce-client-jobclient
    -          
    -          
    -            org.apache.hadoop
    -            hadoop-mapreduce-client-core
    -          
    -          
    -            org.apache.hadoop
    -            hadoop-auth
    -          
    -          
    -            
    -            org.apache.hbase
    -            hbase-annotations
    -          
    -          
    -            org.apache.hadoop
    -            hadoop-annotations
    -          
    -          
    -            org.apache.hadoop
    -            hadoop-hdfs
    -          
    -          
    -            org.apache.hbase
    -            hbase-hadoop1-compat
    -          
    -          
    -            org.apache.commons
    -            commons-math
    -          
    -          
    -            com.sun.jersey
    -            jersey-core
    -          
    -          
    -            org.slf4j
    -            slf4j-api
    -          
    -          
    -            com.sun.jersey
    -            jersey-server
    -          
    -          
    -            com.sun.jersey
    -            jersey-core
    -          
    -          
    -            com.sun.jersey
    -            jersey-json
    -          
    -          
    -            
    -            commons-io
    -            commons-io
    -          
    -        
    -      
    -      
    -        org.apache.hbase
    -        hbase-hadoop-compat
    -        ${hbase.version}
    -      
    -      
    -        org.apache.hbase
    -        hbase-hadoop-compat
    -        ${hbase.version}
    -        test-jar
    -        test
    -      
         
           org.apache.commons
           commons-math3
    @@ -308,31 +310,6 @@
           
             org.apache.maven.plugins
             maven-shade-plugin
    -        
    -          false
    -          ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar
    -          
    -            
    -              *:*
    -            
    -          
    -          
    -            
    -              com.google.guava:guava
    -              
    -                com/google/common/base/Optional*
    -              
    -            
    -            
    -              *:*
    -              
    -                META-INF/*.SF
    -                META-INF/*.DSA
    -                META-INF/*.RSA
    -              
    -            
    -          
    -        
             
               
                 package
    @@ -340,6 +317,34 @@
                   shade
                 
                 
    +            false
    +            ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar
    +            
    +              
    +                *:*
    +              
    +            
    +            
    +              
    +                com.google.guava:guava
    +                
    +                  
    +                  **
    +                
    +              
    +              
    +                *:*
    +                
    +                  META-INF/*.SF
    +                  META-INF/*.DSA
    +                  META-INF/*.RSA
    +                
    +              
    +            
                   
                     
                       com.google
    @@ -411,7 +416,7 @@
           
         
         
    -      
           scala-2.10
           
    @@ -449,5 +454,37 @@
             
           
         
    +
    +    
    +    
    +      flume-provided
    +      
    +        provided
    +      
    +    
    +    
    +      hadoop-provided
    +      
    +        provided
    +      
    +    
    +    
    +      hbase-provided
    +      
    +        provided
    +      
    +    
    +    
    +      hive-provided
    +      
    +        provided
    +      
    +    
    +    
    +      parquet-provided
    +      
    +        provided
    +      
    +    
       
     
    diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
    index 71f595d0a680..0706f1ebf66e 100644
    --- a/external/flume-sink/pom.xml
    +++ b/external/flume-sink/pom.xml
    @@ -38,32 +38,10 @@
         
           org.apache.flume
           flume-ng-sdk
    -      ${flume.version}
    -      
    -        
    -          io.netty
    -          netty
    -        
    -        
    -          org.apache.thrift
    -          libthrift
    -        
    -      
         
         
           org.apache.flume
           flume-ng-core
    -      ${flume.version}
    -      
    -        
    -          io.netty
    -          netty
    -        
    -        
    -          org.apache.thrift
    -          libthrift
    -        
    -      
         
         
           org.scala-lang
    diff --git a/external/flume/pom.xml b/external/flume/pom.xml
    index 0374262212e0..1f2681394c58 100644
    --- a/external/flume/pom.xml
    +++ b/external/flume/pom.xml
    @@ -46,20 +46,13 @@
           spark-streaming-flume-sink_${scala.binary.version}
           ${project.version}
         
    +    
    +      org.apache.flume
    +      flume-ng-core
    +    
         
           org.apache.flume
           flume-ng-sdk
    -      ${flume.version}
    -      
    -        
    -          io.netty
    -          netty
    -        
    -        
    -          org.apache.thrift
    -          libthrift
    -        
    -      
         
         
           org.scalacheck
    diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
    index 2fb5f0ed2f57..e919c2c9b19e 100644
    --- a/external/zeromq/pom.xml
    +++ b/external/zeromq/pom.xml
    @@ -44,7 +44,6 @@
         
           ${akka.group}
           akka-zeromq_${scala.binary.version}
    -      ${akka.version}
         
         
           org.scalacheck
    diff --git a/graphx/pom.xml b/graphx/pom.xml
    index 91db799d244a..72374aae6da9 100644
    --- a/graphx/pom.xml
    +++ b/graphx/pom.xml
    @@ -45,10 +45,6 @@
           jblas
           ${jblas.version}
         
    -    
    -      org.eclipse.jetty
    -      jetty-server
    -    
         
           org.scalacheck
           scalacheck_${scala.binary.version}
    diff --git a/mllib/pom.xml b/mllib/pom.xml
    index 219875748168..a0bda89ccaa7 100644
    --- a/mllib/pom.xml
    +++ b/mllib/pom.xml
    @@ -29,7 +29,7 @@
       spark-mllib_2.10
       
         mllib
    -    
    +  
       jar
       Spark Project ML Library
       http://spark.apache.org/
    @@ -50,10 +50,6 @@
           spark-sql_${scala.binary.version}
           ${project.version}
         
    -    
    -      org.eclipse.jetty
    -      jetty-server
    -    
         
           org.jblas
           jblas
    diff --git a/pom.xml b/pom.xml
    index 46ff211f9116..703e5c47bf59 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -123,8 +123,10 @@
         2.4.1
         ${hadoop.version}
         0.94.6
    +    hbase
         1.4.0
         3.4.5
    +    org.spark-project.hive
         
         0.13.1a
         
    @@ -143,13 +145,36 @@
         4.2.6
         3.1.1
         ${project.build.directory}/spark-test-classpath.txt
    -    64m
    -    512m
         2.10.4
         2.10
         ${scala.version}
         org.scala-lang
    -    1.8.8
    +    1.8.8
    +    1.1.1.6
    +
    +    
    +    compile
    +    compile
    +    compile
    +    compile
    +    compile
    +
    +    
    +    ${session.executionRootDirectory}
    +
    +    64m
    +    512m
    +    512m
       
     
       
    @@ -244,21 +269,20 @@
           
         
       
    -
       
    -  
    +    
         
           org.spark-project.spark
           unused
           1.0.0
         
         
         
           org.codehaus.groovy
    @@ -369,11 +393,13 @@
             org.slf4j
             slf4j-api
             ${slf4j.version}
    +        ${hadoop.deps.scope}
           
           
             org.slf4j
             slf4j-log4j12
             ${slf4j.version}
    +        ${hadoop.deps.scope}
           
           
             org.slf4j
    @@ -390,6 +416,7 @@
             log4j
             log4j
             ${log4j.version}
    +        ${hadoop.deps.scope}
           
           
             com.ning
    @@ -399,7 +426,8 @@
           
             org.xerial.snappy
             snappy-java
    -        1.1.1.6
    +        ${snappy.version}
    +        ${hadoop.deps.scope}
           
           
             net.jpountz.lz4
    @@ -427,6 +455,7 @@
             com.google.protobuf
             protobuf-java
             ${protobuf.version}
    +        ${hadoop.deps.scope}
           
           
             ${akka.group}
    @@ -448,6 +477,17 @@
             akka-testkit_${scala.binary.version}
             ${akka.version}
           
    +      
    +        ${akka.group}
    +        akka-zeromq_${scala.binary.version}
    +        ${akka.version}
    +        
    +          
    +            ${akka.group}
    +            akka-actor_${scala.binary.version}
    +          
    +        
    +      
           
             org.apache.mesos
             mesos
    @@ -577,6 +617,7 @@
             org.apache.curator
             curator-recipes
             2.4.0
    +        ${hadoop.deps.scope}
             
               
                 org.jboss.netty
    @@ -588,6 +629,7 @@
             org.apache.hadoop
             hadoop-client
             ${hadoop.version}
    +        ${hadoop.deps.scope}
             
               
                 asm
    @@ -623,11 +665,13 @@
             org.apache.avro
             avro
             ${avro.version}
    +        ${hadoop.deps.scope}
           
           
             org.apache.avro
             avro-ipc
             ${avro.version}
    +        ${hadoop.deps.scope}
             
               
                 io.netty
    @@ -656,6 +700,7 @@
             avro-mapred
             ${avro.version}
             ${avro.mapred.classifier}
    +        ${hive.deps.scope}
             
               
                 io.netty
    @@ -684,6 +729,7 @@
             net.java.dev.jets3t
             jets3t
             ${jets3t.version}
    +        ${hadoop.deps.scope}
             
               
                 commons-logging
    @@ -695,6 +741,7 @@
             org.apache.hadoop
             hadoop-yarn-api
             ${yarn.version}
    +        ${hadoop.deps.scope}
             
               
                 javax.servlet
    @@ -722,6 +769,7 @@
             org.apache.hadoop
             hadoop-yarn-common
             ${yarn.version}
    +        ${hadoop.deps.scope}
             
               
                 asm
    @@ -778,6 +826,7 @@
             org.apache.hadoop
             hadoop-yarn-server-web-proxy
             ${yarn.version}
    +        ${hadoop.deps.scope}
             
               
                 asm
    @@ -805,6 +854,7 @@
             org.apache.hadoop
             hadoop-yarn-client
             ${yarn.version}
    +        ${hadoop.deps.scope}
             
               
                 asm
    @@ -829,15 +879,126 @@
             
           
           
    -        
    -        org.codehaus.jackson
    -        jackson-mapper-asl
    -        ${jackson.version}
    +        org.apache.zookeeper
    +        zookeeper
    +        ${zookeeper.version}
    +        ${hadoop.deps.scope}
           
           
             org.codehaus.jackson
             jackson-core-asl
    -        ${jackson.version}
    +        ${codehaus.jackson.version}
    +        ${hadoop.deps.scope}
    +      
    +      
    +        org.codehaus.jackson
    +        jackson-mapper-asl
    +        ${codehaus.jackson.version}
    +        ${hadoop.deps.scope}
    +      
    +      
    +        ${hive.group}
    +        hive-beeline
    +        ${hive.version}
    +        ${hive.deps.scope}
    +      
    +      
    +        ${hive.group}
    +        hive-cli
    +        ${hive.version}
    +        ${hive.deps.scope}
    +      
    +      
    +        ${hive.group}
    +        hive-exec
    +        ${hive.version}
    +        ${hive.deps.scope}
    +        
    +          
    +            commons-logging
    +            commons-logging
    +          
    +          
    +            com.esotericsoftware.kryo
    +            kryo
    +          
    +        
    +      
    +      
    +        ${hive.group}
    +        hive-jdbc
    +        ${hive.version}
    +        ${hive.deps.scope}
    +      
    +      
    +        ${hive.group}
    +        hive-metastore
    +        ${hive.version}
    +        ${hive.deps.scope}
    +      
    +      
    +        ${hive.group}
    +        hive-serde
    +        ${hive.version}
    +        ${hive.deps.scope}
    +        
    +          
    +            commons-logging
    +            commons-logging
    +          
    +          
    +            commons-logging
    +            commons-logging-api
    +          
    +        
    +      
    +      
    +        com.twitter
    +        parquet-column
    +        ${parquet.version}
    +        ${parquet.deps.scope}
    +      
    +      
    +        com.twitter
    +        parquet-hadoop
    +        ${parquet.version}
    +        ${parquet.deps.scope}
    +      
    +      
    +        org.apache.flume
    +        flume-ng-core
    +        ${flume.version}
    +        ${flume.deps.scope}
    +        
    +          
    +            io.netty
    +            netty
    +          
    +          
    +            org.apache.thrift
    +            libthrift
    +          
    +          
    +            org.mortbay.jetty
    +            servlet-api
    +          
    +        
    +      
    +      
    +        org.apache.flume
    +        flume-ng-sdk
    +        ${flume.version}
    +        ${flume.deps.scope}
    +        
    +          
    +            io.netty
    +            netty
    +          
    +          
    +            org.apache.thrift
    +            libthrift
    +          
    +        
           
         
       
    @@ -914,6 +1075,7 @@
                   -Xmx1024m
                   -XX:PermSize=${PermGen}
                   -XX:MaxPermSize=${MaxPermGen}
    +              -XX:ReservedCodeCacheSize=${CodeCacheSize}
                 
                 
                   -source
    @@ -980,15 +1142,21 @@
                 ${project.build.directory}/surefire-reports
                 .
                 SparkTestSuite.txt
    -            -ea -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
    +            -ea -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize}
                 
    +            
    +              
    +              ${test_classpath}
    +            
                 
                   true
    -              ${session.executionRootDirectory}
    +              ${spark.test.home}
                   1
                   false
                   false
    -              ${test_classpath}
                   true
                 
               
    @@ -1011,11 +1179,6 @@
               maven-antrun-plugin
               1.7
             
    -        
    -          org.apache.maven.plugins
    -          maven-shade-plugin
    -          2.2
    -        
             
               org.apache.maven.plugins
               maven-source-plugin
    @@ -1104,6 +1267,7 @@
           
             org.apache.maven.plugins
             maven-shade-plugin
    +        2.2
             
               false
               
    @@ -1373,53 +1537,6 @@
           
         
     
    -    
    -    
    -      hadoop-provided
    -      
    -        
    -          org.apache.hadoop
    -          hadoop-client
    -          provided
    -        
    -        
    -          org.apache.hadoop
    -          hadoop-yarn-api
    -          provided
    -        
    -        
    -          org.apache.hadoop
    -          hadoop-yarn-common
    -          provided
    -        
    -        
    -          org.apache.hadoop
    -          hadoop-yarn-server-web-proxy
    -          provided
    -        
    -        
    -          org.apache.hadoop
    -          hadoop-yarn-client
    -          provided
    -        
    -        
    -          org.apache.avro
    -          avro
    -          provided
    -        
    -        
    -          org.apache.avro
    -          avro-ipc
    -          provided
    -        
    -        
    -          org.apache.zookeeper
    -          zookeeper
    -          ${zookeeper.version}
    -          provided
    -        
    -      
    -    
         
           hive-thriftserver
           
    @@ -1472,5 +1589,25 @@
           
         
     
    +    
    +    
    +      flume-provided
    +    
    +    
    +      hadoop-provided
    +    
    +    
    +      hbase-provided
    +    
    +    
    +      hive-provided
    +    
    +    
    +      parquet-provided
    +    
       
     
    diff --git a/repl/pom.xml b/repl/pom.xml
    index 97165e024926..0bc8bccf90a6 100644
    --- a/repl/pom.xml
    +++ b/repl/pom.xml
    @@ -68,10 +68,6 @@
           ${project.version}
           test
         
    -    
    -      org.eclipse.jetty
    -      jetty-server
    -    
         
           org.scala-lang
           scala-compiler
    @@ -103,13 +99,6 @@
               true
             
           
    -      
    -        org.apache.maven.plugins
    -        maven-install-plugin
    -        
    -          true
    -        
    -      
           
           
             org.codehaus.mojo
    diff --git a/sql/core/pom.xml b/sql/core/pom.xml
    index 023ce2041bb8..3e9ef07df9db 100644
    --- a/sql/core/pom.xml
    +++ b/sql/core/pom.xml
    @@ -56,12 +56,10 @@
         
           com.twitter
           parquet-column
    -      ${parquet.version}
         
         
           com.twitter
           parquet-hadoop
    -      ${parquet.version}
         
         
           com.fasterxml.jackson.core
    diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
    index d3a517375cf2..259eef0b80d0 100644
    --- a/sql/hive-thriftserver/pom.xml
    +++ b/sql/hive-thriftserver/pom.xml
    @@ -42,19 +42,16 @@
           ${project.version}
         
         
    -      org.spark-project.hive
    +      ${hive.group}
           hive-cli
    -      ${hive.version}
         
         
    -      org.spark-project.hive
    +      ${hive.group}
           hive-jdbc
    -      ${hive.version}
         
         
    -      org.spark-project.hive
    +      ${hive.group}
           hive-beeline
    -      ${hive.version}
         
       
       
    diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
    index e8ffbc5b954d..60953576d0e3 100644
    --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
    +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
    @@ -48,6 +48,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
              |  --master local
              |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
              |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
    +         |  --driver-class-path ${sys.props("java.class.path")}
            """.stripMargin.split("\\s+").toSeq ++ extraArgs
         }
     
    @@ -70,7 +71,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
         }
     
         // Searching expected output line from both stdout and stderr of the CLI process
    -    val process = (Process(command) #< queryStream).run(
    +    val process = (Process(command, None) #< queryStream).run(
           ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
     
         try {
    diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
    index 94d5ed4f1d15..7814aa38f414 100644
    --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
    +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
    @@ -142,6 +142,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
                  |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
                  |  --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http
                  |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port
    +             |  --driver-class-path ${sys.props("java.class.path")}
                """.stripMargin.split("\\s+").toSeq
           } else {
               s"""$startScript
    @@ -151,6 +152,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
                  |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
                  |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
                  |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
    +             |  --driver-class-path ${sys.props("java.class.path")}
                """.stripMargin.split("\\s+").toSeq
           }
     
    @@ -179,8 +181,9 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
           }
         }
     
    -    // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
    -    val env = Seq("SPARK_TESTING" -> "0")
    +    val env = Seq(
    +      // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
    +      "SPARK_TESTING" -> "0")
     
         Process(command, None, env: _*).run(ProcessLogger(
           captureThriftServerOutput("stdout"),
    @@ -214,7 +217,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
         } finally {
           warehousePath.delete()
           metastorePath.delete()
    -      Process(stopScript).run().exitValue()
    +      Process(stopScript, None, env: _*).run().exitValue()
           // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
           Thread.sleep(3.seconds.toMillis)
           Option(logTailingProcess).map(_.destroy())
    diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
    index 46aacad01113..58b0722464be 100644
    --- a/sql/hive/pom.xml
    +++ b/sql/hive/pom.xml
    @@ -47,9 +47,8 @@
           ${project.version}
         
         
    -      org.spark-project.hive
    +      ${hive.group}
           hive-metastore
    -      ${hive.version}
         
         
           commons-httpclient
    @@ -57,51 +56,27 @@
           3.1
         
         
    -      org.spark-project.hive
    +      ${hive.group}
           hive-exec
    -      ${hive.version}
    -      
    -        
    -          commons-logging
    -          commons-logging
    -        
    -        
    -          com.esotericsoftware.kryo
    -          kryo
    -        
    -      
         
         
           org.codehaus.jackson
           jackson-mapper-asl
         
         
    -      org.spark-project.hive
    +      ${hive.group}
           hive-serde
    -      ${hive.version}
    -      
    -        
    -          commons-logging
    -          commons-logging
    -        
    -        
    -          commons-logging
    -          commons-logging-api
    -        
    -      
         
         
         
           org.apache.avro
           avro
    -      ${avro.version}
         
         
         
           org.apache.avro
           avro-mapred
    -      ${avro.version}
           ${avro.mapred.classifier}
         
         
    diff --git a/streaming/pom.xml b/streaming/pom.xml
    index 2023210d9b9b..d3c6d0347a62 100644
    --- a/streaming/pom.xml
    +++ b/streaming/pom.xml
    @@ -68,13 +68,13 @@
         target/scala-${scala.binary.version}/classes
         target/scala-${scala.binary.version}/test-classes
         
    -      
           
    diff --git a/yarn/pom.xml b/yarn/pom.xml
    index bcb77b3e3c70..b86857db7bde 100644
    --- a/yarn/pom.xml
    +++ b/yarn/pom.xml
    @@ -131,13 +131,6 @@
               true
             
           
    -      <plugin>
    -        <groupId>org.apache.maven.plugins</groupId>
    -        <artifactId>maven-install-plugin</artifactId>
    -        <configuration>
    -          <skip>true</skip>
    -        </configuration>
    -      </plugin>
         
     
         target/scala-${scala.binary.version}/classes
    diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    index 8d0543771309..c363d755c175 100644
    --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    @@ -367,6 +367,10 @@ private[spark] class Client(
           }
         }
     
    +    sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp =>
    +      env(ENV_DIST_CLASSPATH) = dcp
    +    }
    +
         env
       }
     
    @@ -652,6 +656,9 @@ object Client extends Logging {
       val APP_FILE_PERMISSION: FsPermission =
         FsPermission.createImmutable(Integer.parseInt("644", 8).toShort)
     
    +  // Distribution-defined classpath to add to processes
    +  val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH"
    +
       /**
        * Find the user-defined Spark jar if configured, or return the jar containing this
        * class if not.
    
    From 167a5ab0bd1d37f3ac23bec49e484a238610cf75 Mon Sep 17 00:00:00 2001
    From: Nicholas Chammas 
    Date: Thu, 8 Jan 2015 17:42:08 -0800
    Subject: [PATCH 091/215] [SPARK-5122] Remove Shark from spark-ec2
    
    I moved the Spark-Shark version map [to the wiki](https://cwiki.apache.org/confluence/display/SPARK/Spark-Shark+version+mapping).
    
    This PR has a [matching PR in mesos/spark-ec2](https://github.com/mesos/spark-ec2/pull/89).
    
    Author: Nicholas Chammas 
    
    Closes #3939 from nchammas/remove-shark and squashes the following commits:
    
    66e0841 [Nicholas Chammas] fix style
    ceeab85 [Nicholas Chammas] show default Spark GitHub repo
    7270126 [Nicholas Chammas] validate Spark hashes
    db4935d [Nicholas Chammas] validate spark version upfront
    fc0d5b9 [Nicholas Chammas] remove Shark
    ---
     ec2/spark_ec2.py | 78 +++++++++++++++++++++++++++---------------------
     1 file changed, 44 insertions(+), 34 deletions(-)
    
    diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
    index 485eea4f5e68..abab209a05ba 100755
    --- a/ec2/spark_ec2.py
    +++ b/ec2/spark_ec2.py
    @@ -39,10 +39,26 @@
     from optparse import OptionParser
     from sys import stderr
     
    +VALID_SPARK_VERSIONS = set([
    +    "0.7.3",
    +    "0.8.0",
    +    "0.8.1",
    +    "0.9.0",
    +    "0.9.1",
    +    "0.9.2",
    +    "1.0.0",
    +    "1.0.1",
    +    "1.0.2",
    +    "1.1.0",
    +    "1.1.1",
    +    "1.2.0",
    +])
    +
     DEFAULT_SPARK_VERSION = "1.2.0"
    +DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark"
     SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
    -
     MESOS_SPARK_EC2_BRANCH = "branch-1.3"
    +
     # A URL prefix from which to fetch AMI information
     AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH)
     
    @@ -126,8 +142,8 @@ def parse_args():
             help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)")
         parser.add_option(
             "--spark-git-repo",
    -        default="https://github.com/apache/spark",
    -        help="Github repo from which to checkout supplied commit hash")
    +        default=DEFAULT_SPARK_GITHUB_REPO,
    +        help="Github repo from which to checkout supplied commit hash (default: %default)")
         parser.add_option(
             "--hadoop-major-version", default="1",
             help="Major version of Hadoop (default: %default)")
    @@ -236,6 +252,26 @@ def get_or_make_group(conn, name, vpc_id):
             return conn.create_security_group(name, "Spark EC2 group", vpc_id)
     
     
    +def get_validate_spark_version(version, repo):
    +    if "." in version:
    +        version = version.replace("v", "")
    +        if version not in VALID_SPARK_VERSIONS:
    +            print >> stderr, "Don't know about Spark version: {v}".format(v=version)
    +            sys.exit(1)
    +        return version
    +    else:
    +        github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version)
    +        request = urllib2.Request(github_commit_url)
    +        request.get_method = lambda: 'HEAD'
    +        try:
    +            response = urllib2.urlopen(request)
    +        except urllib2.HTTPError, e:
    +            print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url)
    +            print >> stderr, "Received HTTP response code of {code}.".format(code=e.code)
    +            sys.exit(1)
    +        return version
    +
    +
     # Check whether a given EC2 instance object is in a state we consider active,
     # i.e. not terminating or terminated. We count both stopping and stopped as
     # active since we can restart stopped clusters.
    @@ -243,29 +279,6 @@ def is_active(instance):
         return (instance.state in ['pending', 'running', 'stopping', 'stopped'])
     
     
    -# Return correct versions of Spark and Shark, given the supplied Spark version
    -def get_spark_shark_version(opts):
    -    spark_shark_map = {
    -        "0.7.3": "0.7.1",
    -        "0.8.0": "0.8.0",
    -        "0.8.1": "0.8.1",
    -        "0.9.0": "0.9.0",
    -        "0.9.1": "0.9.1",
    -        # These are dummy versions (no Shark versions after this)
    -        "1.0.0": "1.0.0",
    -        "1.0.1": "1.0.1",
    -        "1.0.2": "1.0.2",
    -        "1.1.0": "1.1.0",
    -        "1.1.1": "1.1.1",
    -        "1.2.0": "1.2.0",
    -    }
    -    version = opts.spark_version.replace("v", "")
    -    if version not in spark_shark_map:
    -        print >> stderr, "Don't know about Spark version: %s" % version
    -        sys.exit(1)
    -    return (version, spark_shark_map[version])
    -
    -
     # Attempt to resolve an appropriate AMI given the architecture and region of the request.
     # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/
     # Last Updated: 2014-06-20
    @@ -619,7 +632,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
                 print slave.public_dns_name
                 ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
     
    -    modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
    +    modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs',
                    'mapreduce', 'spark-standalone', 'tachyon']
     
         if opts.hadoop_major_version == "1":
    @@ -706,9 +719,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state):
         sys.stdout.flush()
     
         start_time = datetime.now()
    -
         num_attempts = 0
    -    conn = ec2.connect_to_region(opts.region)
     
         while True:
             time.sleep(5 * num_attempts)  # seconds
    @@ -815,13 +826,11 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
         cluster_url = "%s:7077" % active_master
     
         if "." in opts.spark_version:
    -        # Pre-built spark & shark deploy
    -        (spark_v, shark_v) = get_spark_shark_version(opts)
    +        # Pre-built Spark deploy
    +        spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo)
         else:
             # Spark-only custom deploy
             spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version)
    -        shark_v = ""
    -        modules = filter(lambda x: x != "shark", modules)
     
         template_vars = {
             "master_list": '\n'.join([i.public_dns_name for i in master_nodes]),
    @@ -834,7 +843,6 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
             "swap": str(opts.swap),
             "modules": '\n'.join(modules),
             "spark_version": spark_v,
    -        "shark_version": shark_v,
             "hadoop_major_version": opts.hadoop_major_version,
             "spark_worker_instances": "%d" % opts.worker_instances,
             "spark_master_opts": opts.master_opts
    @@ -983,6 +991,8 @@ def real_main():
         (opts, action, cluster_name) = parse_args()
     
         # Input parameter validation
    +    get_validate_spark_version(opts.spark_version, opts.spark_git_repo)
    +
         if opts.wait is not None:
             # NOTE: DeprecationWarnings are silent in 2.7+ by default.
             #       To show them, run Python with the -Wdefault switch.
    
    From f3da4bd7289d493014ad3c5176ada60794dfcfe0 Mon Sep 17 00:00:00 2001
    From: WangTaoTheTonic 
    Date: Fri, 9 Jan 2015 08:10:09 -0600
    Subject: [PATCH 092/215] [SPARK-5169][YARN]fetch the correct max attempts
    
    Sorry for fetching the wrong max attempts in this commit: https://github.com/apache/spark/commit/8fdd48959c93b9cf809f03549e2ae6c4687d1fcd.
    We need to fix it now.
    
    tgravescs
    
    If we set a spark.yarn.maxAppAttempts that is larger than `yarn.resourcemanager.am.max-attempts` on the YARN side, it will be overridden, as described here:
    >The maximum number of application attempts. It's a global setting for all application masters. Each application master can specify its individual maximum number of application attempts via the API, but the individual number cannot be more than the global upper bound. If it is, the resourcemanager will override it. The default number is set to 2, to allow at least one retry for AM.
    
    http://hadoop.apache.org/docs/r2.6.0/hadoop-yarn/hadoop-yarn-common/yarn-default.xml
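
    As a standalone sketch, the capping rule this patch implements (illustrative only; the real logic is in YarnRMClient.getMaxRegAttempts below) boils down to:

      // Effective AM attempts: the Spark setting, capped by YARN's global maximum.
      def effectiveMaxAttempts(sparkMaxAttempts: Option[Int], yarnMaxAttempts: Int): Int =
        sparkMaxAttempts match {
          case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
          case None => yarnMaxAttempts
        }

      // e.g. effectiveMaxAttempts(Some(10), 2) == 2 and effectiveMaxAttempts(None, 2) == 2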
    
    Author: WangTaoTheTonic 
    
    Closes #3942 from WangTaoTheTonic/HOTFIX and squashes the following commits:
    
    9ac16ce [WangTaoTheTonic] fetch the correct max attempts
    ---
     .../org/apache/spark/deploy/yarn/YarnRMClient.scala  | 12 +++++++++---
     1 file changed, 9 insertions(+), 3 deletions(-)
    
    diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
    index e183efccbb6f..b45e599588ad 100644
    --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
    +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
    @@ -121,9 +121,15 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
     
       /** Returns the maximum number of attempts to register the AM. */
       def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
    -    sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt).getOrElse(
    -      yarnConf.getInt(
    -        YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS))
    +    val sparkMaxAttempts = sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt)
    +    val yarnMaxAttempts = yarnConf.getInt(
    +      YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
    +    val retval: Int = sparkMaxAttempts match {
    +      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
    +      case None => yarnMaxAttempts
    +    }
    +
    +    retval
       }
     
     }
    
    From b4034c3f889bf24f60eb806802866b48e4cbe55c Mon Sep 17 00:00:00 2001
    From: Aaron Davidson 
    Date: Fri, 9 Jan 2015 09:20:16 -0800
    Subject: [PATCH 093/215] [Minor] Fix test RetryingBlockFetcherSuite after
     changed config name
    
    Flaky due to the default retry interval being the same as our test's wait timeout.
    
    Author: Aaron Davidson 
    
    Closes #3972 from aarondav/fix-test and squashes the following commits:
    
    db77cab [Aaron Davidson] [Minor] Fix test after changed config name
    ---
     .../spark/network/shuffle/RetryingBlockFetcherSuite.java      | 4 ++--
     1 file changed, 2 insertions(+), 2 deletions(-)
    
    diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
    index 0191fe529e1b..1ad0d72ae5ec 100644
    --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
    +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
    @@ -54,13 +54,13 @@ public class RetryingBlockFetcherSuite {
       @Before
       public void beforeEach() {
         System.setProperty("spark.shuffle.io.maxRetries", "2");
    -    System.setProperty("spark.shuffle.io.retryWaitMs", "0");
    +    System.setProperty("spark.shuffle.io.retryWait", "0");
       }
     
       @After
       public void afterEach() {
         System.clearProperty("spark.shuffle.io.maxRetries");
    -    System.clearProperty("spark.shuffle.io.retryWaitMs");
    +    System.clearProperty("spark.shuffle.io.retryWait");
       }
     
       @Test
    
    From 547df97715580f99ae573a49a86da12bf20cbc3d Mon Sep 17 00:00:00 2001
    From: Sean Owen 
    Date: Fri, 9 Jan 2015 09:35:46 -0800
    Subject: [PATCH 094/215] SPARK-5136 [DOCS] Improve documentation around
     setting up Spark IntelliJ project
    
    This PR simply points to the IntelliJ wiki page instead of also including IntelliJ notes in the docs. The intent, however, is also to update the wiki page with refreshed tips. This is the text I propose for the IntelliJ section on the wiki. I realize it omits some of the existing wiki instructions about enabling Hive, but I think those are actually optional.
    
    ------
    
    IntelliJ supports both Maven- and SBT-based projects. It is recommended, however, to import Spark as a Maven project. Choose "Import Project..." from the File menu, and select the `pom.xml` file in the Spark root directory.
    
    It is fine to leave all settings at their default values in the Maven import wizard, with two caveats. First, it is usually useful to enable "Import Maven projects automatically", since changes to the project structure will then be reflected in the IntelliJ project automatically.
    
    Second, note the step that prompts you to choose active Maven build profiles. As documented above, some build configurations require specific profiles to be enabled. The same profiles that are enabled with `-P[profile name]` above may be enabled on this screen. For example, if developing for Hadoop 2.4 with YARN support, enable profiles `yarn` and `hadoop-2.4`.
    
    These selections can be changed later by accessing the "Maven Projects" tool window from the View menu, and expanding the Profiles section.
    
    "Rebuild Project" can fail the first time the project is compiled, because generated source files are not created automatically. Try clicking the "Generate Sources and Update Folders For All Projects" button in the "Maven Projects" tool window to generate these sources manually.
    
    Compilation may fail with an error like "scalac: bad option: -P:/home/jakub/.m2/repository/org/scalamacros/paradise_2.10.4/2.0.1/paradise_2.10.4-2.0.1.jar". If so, go to Preferences > Build, Execution, Deployment > Scala Compiler and clear the "Additional compiler options" field. Compilation will then work, although the option will come back when the project is reimported.
    
    Author: Sean Owen 
    
    Closes #3952 from srowen/SPARK-5136 and squashes the following commits:
    
    f3baa66 [Sean Owen] Point to new IJ / Eclipse wiki link
    016b7df [Sean Owen] Point to IntelliJ wiki page instead of also including IntelliJ notes in the docs
    ---
     docs/building-spark.md | 5 +++--
     1 file changed, 3 insertions(+), 2 deletions(-)
    
    diff --git a/docs/building-spark.md b/docs/building-spark.md
    index c1bcd91b5b85..fb93017861ed 100644
    --- a/docs/building-spark.md
    +++ b/docs/building-spark.md
    @@ -151,9 +151,10 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m
      $ mvn scala:cc
     ```
     
    -# Using With IntelliJ IDEA
    +# Building Spark with IntelliJ IDEA or Eclipse
     
    -This setup works fine in IntelliJ IDEA 11.1.4. After opening the project via the pom.xml file in the project root folder, you only need to activate either the hadoop1 or hadoop2 profile in the "Maven Properties" popout. We have not tried Eclipse/Scala IDE with this.
    +For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the
    +[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup).
     
     # Building Spark Debian Packages
     
    
    From 1790b38695b46400a24b0b7e278e8e8388748211 Mon Sep 17 00:00:00 2001
    From: Patrick Wendell 
    Date: Fri, 9 Jan 2015 09:40:18 -0800
    Subject: [PATCH 095/215] HOTFIX: Minor improvements to make-distribution.sh
    
    1. Renames $FWDIR to $SPARK_HOME (vast majority of diff).
    2. Uses Spark-provided Maven.
    3. Logs build flags in the RELEASE file.
    
    Author: Patrick Wendell 
    
    Closes #3973 from pwendell/master and squashes the following commits:
    
    340a2fa [Patrick Wendell] HOTFIX: Minor improvements to make-distribution.sh
    ---
     make-distribution.sh | 61 ++++++++++++++++++++++++--------------------
     1 file changed, 34 insertions(+), 27 deletions(-)
    
    diff --git a/make-distribution.sh b/make-distribution.sh
    index 45c99e42e5a5..4e2f400be305 100755
    --- a/make-distribution.sh
    +++ b/make-distribution.sh
    @@ -28,18 +28,20 @@ set -o pipefail
     set -e
     
     # Figure out where the Spark framework is installed
    -FWDIR="$(cd "`dirname "$0"`"; pwd)"
    -DISTDIR="$FWDIR/dist"
    +SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
    +DISTDIR="$SPARK_HOME/dist"
     
     SPARK_TACHYON=false
     MAKE_TGZ=false
     NAME=none
    +MVN="$SPARK_HOME/build/mvn"
     
     function exit_with_usage {
       echo "make-distribution.sh - tool for making binary distributions of Spark"
       echo ""
       echo "usage:"
    -  echo "./make-distribution.sh [--name] [--tgz] [--with-tachyon] "
    +  cl_options="[--name] [--tgz] [--mvn ] [--with-tachyon]"
    +  echo "./make-distribution.sh $cl_options "
       echo "See Spark's \"Building Spark\" doc for correct Maven options."
       echo ""
       exit 1
    @@ -71,6 +73,10 @@ while (( "$#" )); do
         --tgz)
           MAKE_TGZ=true
           ;;
    +    --mvn)
    +      MVN="$2"
    +      shift
    +      ;;
         --name)
           NAME="$2"
           shift
    @@ -109,9 +115,9 @@ if which git &>/dev/null; then
         unset GITREV
     fi
     
    -if ! which mvn &>/dev/null; then
    -    echo -e "You need Maven installed to build Spark."
    -    echo -e "Download Maven from https://maven.apache.org/"
    +if ! which $MVN &>/dev/null; then
    +    echo -e "Could not locate Maven command: '$MVN'."
    +    echo -e "Specify the Maven command with the --mvn flag"
         exit -1;
     fi
     
    @@ -119,7 +125,7 @@ VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "
     SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\
         | grep -v "INFO"\
         | tail -n 1)
    -SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\
    +SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\
         | grep -v "INFO"\
         | fgrep --count "hive";\
         # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\
    @@ -161,11 +167,11 @@ else
     fi
     
     # Build uber fat JAR
    -cd "$FWDIR"
    +cd "$SPARK_HOME"
     
     export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
     
    -BUILD_COMMAND="mvn clean package -DskipTests $@"
    +BUILD_COMMAND="$MVN clean package -DskipTests $@"
     
     # Actually build the jar
     echo -e "\nBuilding with..."
    @@ -177,41 +183,42 @@ ${BUILD_COMMAND}
     rm -rf "$DISTDIR"
     mkdir -p "$DISTDIR/lib"
     echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE"
    +echo "Build flags: $@" >> "$DISTDIR/RELEASE"
     
     # Copy jars
    -cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/"
    -cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/"
    +cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/"
    +cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/"
     # This will fail if the -Pyarn profile is not provided
     # In this case, silence the error and ignore the return code of this command
    -cp "$FWDIR"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || :
    +cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || :
     
     # Copy example sources (needed for python and SQL)
     mkdir -p "$DISTDIR/examples/src/main"
    -cp -r "$FWDIR"/examples/src/main "$DISTDIR/examples/src/"
    +cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/"
     
     if [ "$SPARK_HIVE" == "1" ]; then
    -  cp "$FWDIR"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/"
    +  cp "$SPARK_HOME"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/"
     fi
     
     # Copy license and ASF files
    -cp "$FWDIR/LICENSE" "$DISTDIR"
    -cp "$FWDIR/NOTICE" "$DISTDIR"
    +cp "$SPARK_HOME/LICENSE" "$DISTDIR"
    +cp "$SPARK_HOME/NOTICE" "$DISTDIR"
     
    -if [ -e "$FWDIR"/CHANGES.txt ]; then
    -  cp "$FWDIR/CHANGES.txt" "$DISTDIR"
    +if [ -e "$SPARK_HOME"/CHANGES.txt ]; then
    +  cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR"
     fi
     
     # Copy data files
    -cp -r "$FWDIR/data" "$DISTDIR"
    +cp -r "$SPARK_HOME/data" "$DISTDIR"
     
     # Copy other things
     mkdir "$DISTDIR"/conf
    -cp "$FWDIR"/conf/*.template "$DISTDIR"/conf
    -cp "$FWDIR/README.md" "$DISTDIR"
    -cp -r "$FWDIR/bin" "$DISTDIR"
    -cp -r "$FWDIR/python" "$DISTDIR"
    -cp -r "$FWDIR/sbin" "$DISTDIR"
    -cp -r "$FWDIR/ec2" "$DISTDIR"
    +cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf
    +cp "$SPARK_HOME/README.md" "$DISTDIR"
    +cp -r "$SPARK_HOME/bin" "$DISTDIR"
    +cp -r "$SPARK_HOME/python" "$DISTDIR"
    +cp -r "$SPARK_HOME/sbin" "$DISTDIR"
    +cp -r "$SPARK_HOME/ec2" "$DISTDIR"
     
     # Download and copy in tachyon, if requested
     if [ "$SPARK_TACHYON" == "true" ]; then
    @@ -243,9 +250,9 @@ fi
     
     if [ "$MAKE_TGZ" == "true" ]; then
       TARDIR_NAME=spark-$VERSION-bin-$NAME
    -  TARDIR="$FWDIR/$TARDIR_NAME"
    +  TARDIR="$SPARK_HOME/$TARDIR_NAME"
       rm -rf "$TARDIR"
       cp -r "$DISTDIR" "$TARDIR"
    -  tar czf "spark-$VERSION-bin-$NAME.tgz" -C "$FWDIR" "$TARDIR_NAME"
    +  tar czf "spark-$VERSION-bin-$NAME.tgz" -C "$SPARK_HOME" "$TARDIR_NAME"
       rm -rf "$TARDIR"
     fi
    
    From b6aa557300275b835cce7baa7bc8a80eb5425cbb Mon Sep 17 00:00:00 2001
    From: Kay Ousterhout 
    Date: Fri, 9 Jan 2015 09:47:06 -0800
    Subject: [PATCH 096/215] [SPARK-1143] Separate pool tests into their own
     suite.
    
    The current TaskSchedulerImplSuite includes some tests that are
    actually for the TaskSchedulerImpl, but the remainder of the tests avoid using
    the TaskSchedulerImpl entirely, and actually test the pool and scheduling
    algorithm mechanisms. This commit separates the pool/scheduling algorithm
    tests into their own suite, and also simplifies those tests.
    
    The pull request replaces #339.
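
    As a rough standalone sketch of the ordering these tests exercise (based on the comments in the new PoolSuite below, not Spark's actual comparator class):

      object FairOrderingSketch {
        case class PoolStat(name: String, runningTasks: Int, minShare: Int, weight: Int)

        // A pool below its minShare is "needy" and runs first; needy pools are ordered by
        // runningTasks/minShare, non-needy pools by runningTasks/weight, names break ties.
        def runsBefore(p1: PoolStat, p2: PoolStat): Boolean = {
          def shareRatio(p: PoolStat) = p.runningTasks.toDouble / math.max(p.minShare, 1)
          def taskRatio(p: PoolStat) = p.runningTasks.toDouble / p.weight
          val (needy1, needy2) = (p1.runningTasks < p1.minShare, p2.runningTasks < p2.minShare)
          if (needy1 != needy2) {
            needy1
          } else {
            val (k1, k2) =
              if (needy1) (shareRatio(p1), shareRatio(p2)) else (taskRatio(p1), taskRatio(p2))
            if (k1 != k2) k1 < k2 else p1.name < p2.name
          }
        }
      }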
    
    Author: Kay Ousterhout 
    
    Closes #3967 from kayousterhout/SPARK-1143 and squashes the following commits:
    
    8a898c4 [Kay Ousterhout] [SPARK-1143] Separate pool tests into their own suite.
    ---
     .../apache/spark/scheduler/PoolSuite.scala    | 183 ++++++++++++++
     .../scheduler/TaskSchedulerImplSuite.scala    | 230 ------------------
     2 files changed, 183 insertions(+), 230 deletions(-)
     create mode 100644 core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
    
    diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
    new file mode 100644
    index 000000000000..e8f461e2f56c
    --- /dev/null
    +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
    @@ -0,0 +1,183 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.scheduler
    +
    +import java.util.Properties
    +
    +import org.scalatest.FunSuite
    +
    +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext}
    +
    +/**
    + * Tests that pools and the associated scheduling algorithms for FIFO and fair scheduling work
    + * correctly.
    + */
    +class PoolSuite extends FunSuite with LocalSparkContext {
    +
    +  def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl)
    +    : TaskSetManager = {
    +    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
    +      new FakeTask(i, Nil)
    +    }
    +    new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null), 0)
    +  }
    +
    +  def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int) {
    +    val taskSetQueue = rootPool.getSortedTaskSetQueue
    +    val nextTaskSetToSchedule =
    +      taskSetQueue.find(t => (t.runningTasks + t.tasksSuccessful) < t.numTasks)
    +    assert(nextTaskSetToSchedule.isDefined)
    +    nextTaskSetToSchedule.get.addRunningTask(taskId)
    +    assert(nextTaskSetToSchedule.get.stageId === expectedStageId)
    +  }
    +
    +  test("FIFO Scheduler Test") {
    +    sc = new SparkContext("local", "TaskSchedulerImplSuite")
    +    val taskScheduler = new TaskSchedulerImpl(sc)
    +
    +    val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
    +    val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
    +    schedulableBuilder.buildPools()
    +
    +    val taskSetManager0 = createTaskSetManager(0, 2, taskScheduler)
    +    val taskSetManager1 = createTaskSetManager(1, 2, taskScheduler)
    +    val taskSetManager2 = createTaskSetManager(2, 2, taskScheduler)
    +    schedulableBuilder.addTaskSetManager(taskSetManager0, null)
    +    schedulableBuilder.addTaskSetManager(taskSetManager1, null)
    +    schedulableBuilder.addTaskSetManager(taskSetManager2, null)
    +
    +    scheduleTaskAndVerifyId(0, rootPool, 0)
    +    scheduleTaskAndVerifyId(1, rootPool, 0)
    +    scheduleTaskAndVerifyId(2, rootPool, 1)
    +    scheduleTaskAndVerifyId(3, rootPool, 1)
    +    scheduleTaskAndVerifyId(4, rootPool, 2)
    +    scheduleTaskAndVerifyId(5, rootPool, 2)
    +  }
    +
    +  /**
    +   * This test creates three scheduling pools, and creates task set managers in the first
    +   * two scheduling pools. The test verifies that as tasks are scheduled, the fair scheduling
    +   * algorithm properly orders the two scheduling pools.
    +   */
    +  test("Fair Scheduler Test") {
    +    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
    +    val conf = new SparkConf().set("spark.scheduler.allocation.file", xmlPath)
    +    sc = new SparkContext("local", "TaskSchedulerImplSuite", conf)
    +    val taskScheduler = new TaskSchedulerImpl(sc)
    +
    +    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
    +    val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
    +    schedulableBuilder.buildPools()
    +
    +    // Ensure that the XML file was read in correctly.
    +    assert(rootPool.getSchedulableByName("default") != null)
    +    assert(rootPool.getSchedulableByName("1") != null)
    +    assert(rootPool.getSchedulableByName("2") != null)
    +    assert(rootPool.getSchedulableByName("3") != null)
    +    assert(rootPool.getSchedulableByName("1").minShare === 2)
    +    assert(rootPool.getSchedulableByName("1").weight === 1)
    +    assert(rootPool.getSchedulableByName("2").minShare === 3)
    +    assert(rootPool.getSchedulableByName("2").weight === 1)
    +    assert(rootPool.getSchedulableByName("3").minShare === 0)
    +    assert(rootPool.getSchedulableByName("3").weight === 1)
    +
    +    val properties1 = new Properties()
    +    properties1.setProperty("spark.scheduler.pool","1")
    +    val properties2 = new Properties()
    +    properties2.setProperty("spark.scheduler.pool","2")
    +
    +    val taskSetManager10 = createTaskSetManager(0, 1, taskScheduler)
    +    val taskSetManager11 = createTaskSetManager(1, 1, taskScheduler)
    +    val taskSetManager12 = createTaskSetManager(2, 2, taskScheduler)
    +    schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
    +    schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
    +    schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
    +
    +    val taskSetManager23 = createTaskSetManager(3, 2, taskScheduler)
    +    val taskSetManager24 = createTaskSetManager(4, 2, taskScheduler)
    +    schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
    +    schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
    +
    +    // Pool 1 share ratio: 0. Pool 2 share ratio: 0. 1 gets scheduled based on ordering of names.
    +    scheduleTaskAndVerifyId(0, rootPool, 0)
    +    // Pool 1 share ratio: 1/2. Pool 2 share ratio: 0. 2 gets scheduled because ratio is lower.
    +    scheduleTaskAndVerifyId(1, rootPool, 3)
    +    // Pool 1 share ratio: 1/2. Pool 2 share ratio: 1/3. 2 gets scheduled because ratio is lower.
    +    scheduleTaskAndVerifyId(2, rootPool, 3)
    +    // Pool 1 share ratio: 1/2. Pool 2 share ratio: 2/3. 1 gets scheduled because ratio is lower.
    +    scheduleTaskAndVerifyId(3, rootPool, 1)
    +    // Pool 1 share ratio: 1. Pool 2 share ratio: 2/3. 2 gets scheduled because ratio is lower.
    +    scheduleTaskAndVerifyId(4, rootPool, 4)
    +    // Neither pool is needy so ordering is based on number of running tasks.
    +    // Pool 1 running tasks: 2, Pool 2 running tasks: 3. 1 gets scheduled because fewer running
    +    // tasks.
    +    scheduleTaskAndVerifyId(5, rootPool, 2)
    +    // Pool 1 running tasks: 3, Pool 2 running tasks: 3. 1 gets scheduled because of naming
    +    // ordering.
    +    scheduleTaskAndVerifyId(6, rootPool, 2)
    +    // Pool 1 running tasks: 4, Pool 2 running tasks: 3. 2 gets scheduled because fewer running
    +    // tasks.
    +    scheduleTaskAndVerifyId(7, rootPool, 4)
    +  }
    +
    +  test("Nested Pool Test") {
    +    sc = new SparkContext("local", "TaskSchedulerImplSuite")
    +    val taskScheduler = new TaskSchedulerImpl(sc)
    +
    +    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
    +    val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
    +    val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
    +    rootPool.addSchedulable(pool0)
    +    rootPool.addSchedulable(pool1)
    +
    +    val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
    +    val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
    +    pool0.addSchedulable(pool00)
    +    pool0.addSchedulable(pool01)
    +
    +    val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
    +    val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
    +    pool1.addSchedulable(pool10)
    +    pool1.addSchedulable(pool11)
    +
    +    val taskSetManager000 = createTaskSetManager(0, 5, taskScheduler)
    +    val taskSetManager001 = createTaskSetManager(1, 5, taskScheduler)
    +    pool00.addSchedulable(taskSetManager000)
    +    pool00.addSchedulable(taskSetManager001)
    +
    +    val taskSetManager010 = createTaskSetManager(2, 5, taskScheduler)
    +    val taskSetManager011 = createTaskSetManager(3, 5, taskScheduler)
    +    pool01.addSchedulable(taskSetManager010)
    +    pool01.addSchedulable(taskSetManager011)
    +
    +    val taskSetManager100 = createTaskSetManager(4, 5, taskScheduler)
    +    val taskSetManager101 = createTaskSetManager(5, 5, taskScheduler)
    +    pool10.addSchedulable(taskSetManager100)
    +    pool10.addSchedulable(taskSetManager101)
    +
    +    val taskSetManager110 = createTaskSetManager(6, 5, taskScheduler)
    +    val taskSetManager111 = createTaskSetManager(7, 5, taskScheduler)
    +    pool11.addSchedulable(taskSetManager110)
    +    pool11.addSchedulable(taskSetManager111)
    +
    +    scheduleTaskAndVerifyId(0, rootPool, 0)
    +    scheduleTaskAndVerifyId(1, rootPool, 4)
    +    scheduleTaskAndVerifyId(2, rootPool, 6)
    +    scheduleTaskAndVerifyId(3, rootPool, 2)
    +  }
    +}
    diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
    index 00812e6018d1..8874cf00e999 100644
    --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
    @@ -30,238 +30,8 @@ class FakeSchedulerBackend extends SchedulerBackend {
       def defaultParallelism() = 1
     }
     
    -class FakeTaskSetManager(
    -    initPriority: Int,
    -    initStageId: Int,
    -    initNumTasks: Int,
    -    taskScheduler: TaskSchedulerImpl,
    -    taskSet: TaskSet)
    -  extends TaskSetManager(taskScheduler, taskSet, 0) {
    -
    -  parent = null
    -  weight = 1
    -  minShare = 2
    -  priority = initPriority
    -  stageId = initStageId
    -  name = "TaskSet_"+stageId
    -  override val numTasks = initNumTasks
    -  tasksSuccessful = 0
    -
    -  var numRunningTasks = 0
    -  override def runningTasks = numRunningTasks
    -
    -  def increaseRunningTasks(taskNum: Int) {
    -    numRunningTasks += taskNum
    -    if (parent != null) {
    -      parent.increaseRunningTasks(taskNum)
    -    }
    -  }
    -
    -  def decreaseRunningTasks(taskNum: Int) {
    -    numRunningTasks -= taskNum
    -    if (parent != null) {
    -      parent.decreaseRunningTasks(taskNum)
    -    }
    -  }
    -
    -  override def addSchedulable(schedulable: Schedulable) {
    -  }
    -
    -  override def removeSchedulable(schedulable: Schedulable) {
    -  }
    -
    -  override def getSchedulableByName(name: String): Schedulable = {
    -    null
    -  }
    -
    -  override def executorLost(executorId: String, host: String): Unit = {
    -  }
    -
    -  override def resourceOffer(
    -      execId: String,
    -      host: String,
    -      maxLocality: TaskLocality.TaskLocality)
    -    : Option[TaskDescription] =
    -  {
    -    if (tasksSuccessful + numRunningTasks < numTasks) {
    -      increaseRunningTasks(1)
    -      Some(new TaskDescription(0, execId, "task 0:0", 0, null))
    -    } else {
    -      None
    -    }
    -  }
    -
    -  override def checkSpeculatableTasks(): Boolean = {
    -    true
    -  }
    -
    -  def taskFinished() {
    -    decreaseRunningTasks(1)
    -    tasksSuccessful +=1
    -    if (tasksSuccessful == numTasks) {
    -      parent.removeSchedulable(this)
    -    }
    -  }
    -
    -  def abort() {
    -    decreaseRunningTasks(numRunningTasks)
    -    parent.removeSchedulable(this)
    -  }
    -}
    -
     class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging {
     
    -  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl,
    -      taskSet: TaskSet): FakeTaskSetManager = {
    -    new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
    -  }
    -
    -  def resourceOffer(rootPool: Pool): Int = {
    -    val taskSetQueue = rootPool.getSortedTaskSetQueue
    -    /* Just for Test*/
    -    for (manager <- taskSetQueue) {
    -       logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
    -         manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
    -    }
    -    for (taskSet <- taskSetQueue) {
    -      taskSet.resourceOffer("execId_1", "hostname_1", TaskLocality.ANY) match {
    -        case Some(task) =>
    -          return taskSet.stageId
    -        case None => {}
    -      }
    -    }
    -    -1
    -  }
    -
    -  def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
    -    assert(resourceOffer(rootPool) === expectedTaskSetId)
    -  }
    -
    -  test("FIFO Scheduler Test") {
    -    sc = new SparkContext("local", "TaskSchedulerImplSuite")
    -    val taskScheduler = new TaskSchedulerImpl(sc)
    -    val taskSet = FakeTask.createTaskSet(1)
    -
    -    val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
    -    val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
    -    schedulableBuilder.buildPools()
    -
    -    val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet)
    -    val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet)
    -    val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet)
    -    schedulableBuilder.addTaskSetManager(taskSetManager0, null)
    -    schedulableBuilder.addTaskSetManager(taskSetManager1, null)
    -    schedulableBuilder.addTaskSetManager(taskSetManager2, null)
    -
    -    checkTaskSetId(rootPool, 0)
    -    resourceOffer(rootPool)
    -    checkTaskSetId(rootPool, 1)
    -    resourceOffer(rootPool)
    -    taskSetManager1.abort()
    -    checkTaskSetId(rootPool, 2)
    -  }
    -
    -  test("Fair Scheduler Test") {
    -    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
    -    val conf = new SparkConf().set("spark.scheduler.allocation.file", xmlPath)
    -    sc = new SparkContext("local", "TaskSchedulerImplSuite", conf)
    -    val taskScheduler = new TaskSchedulerImpl(sc)
    -    val taskSet = FakeTask.createTaskSet(1)
    -
    -    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
    -    val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
    -    schedulableBuilder.buildPools()
    -
    -    assert(rootPool.getSchedulableByName("default") != null)
    -    assert(rootPool.getSchedulableByName("1") != null)
    -    assert(rootPool.getSchedulableByName("2") != null)
    -    assert(rootPool.getSchedulableByName("3") != null)
    -    assert(rootPool.getSchedulableByName("1").minShare === 2)
    -    assert(rootPool.getSchedulableByName("1").weight === 1)
    -    assert(rootPool.getSchedulableByName("2").minShare === 3)
    -    assert(rootPool.getSchedulableByName("2").weight === 1)
    -    assert(rootPool.getSchedulableByName("3").minShare === 0)
    -    assert(rootPool.getSchedulableByName("3").weight === 1)
    -
    -    val properties1 = new Properties()
    -    properties1.setProperty("spark.scheduler.pool","1")
    -    val properties2 = new Properties()
    -    properties2.setProperty("spark.scheduler.pool","2")
    -
    -    val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet)
    -    val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet)
    -    val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet)
    -    schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
    -    schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
    -    schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
    -
    -    val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet)
    -    val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet)
    -    schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
    -    schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
    -
    -    checkTaskSetId(rootPool, 0)
    -    checkTaskSetId(rootPool, 3)
    -    checkTaskSetId(rootPool, 3)
    -    checkTaskSetId(rootPool, 1)
    -    checkTaskSetId(rootPool, 4)
    -    checkTaskSetId(rootPool, 2)
    -    checkTaskSetId(rootPool, 2)
    -    checkTaskSetId(rootPool, 4)
    -
    -    taskSetManager12.taskFinished()
    -    assert(rootPool.getSchedulableByName("1").runningTasks === 3)
    -    taskSetManager24.abort()
    -    assert(rootPool.getSchedulableByName("2").runningTasks === 2)
    -  }
    -
    -  test("Nested Pool Test") {
    -    sc = new SparkContext("local", "TaskSchedulerImplSuite")
    -    val taskScheduler = new TaskSchedulerImpl(sc)
    -    val taskSet = FakeTask.createTaskSet(1)
    -
    -    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
    -    val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
    -    val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
    -    rootPool.addSchedulable(pool0)
    -    rootPool.addSchedulable(pool1)
    -
    -    val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
    -    val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
    -    pool0.addSchedulable(pool00)
    -    pool0.addSchedulable(pool01)
    -
    -    val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
    -    val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
    -    pool1.addSchedulable(pool10)
    -    pool1.addSchedulable(pool11)
    -
    -    val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet)
    -    val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet)
    -    pool00.addSchedulable(taskSetManager000)
    -    pool00.addSchedulable(taskSetManager001)
    -
    -    val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet)
    -    val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet)
    -    pool01.addSchedulable(taskSetManager010)
    -    pool01.addSchedulable(taskSetManager011)
    -
    -    val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet)
    -    val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet)
    -    pool10.addSchedulable(taskSetManager100)
    -    pool10.addSchedulable(taskSetManager101)
    -
    -    val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet)
    -    val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet)
    -    pool11.addSchedulable(taskSetManager110)
    -    pool11.addSchedulable(taskSetManager111)
    -
    -    checkTaskSetId(rootPool, 0)
    -    checkTaskSetId(rootPool, 4)
    -    checkTaskSetId(rootPool, 6)
    -    checkTaskSetId(rootPool, 2)
    -  }
    -
       test("Scheduler does not always schedule tasks on the same workers") {
         sc = new SparkContext("local", "TaskSchedulerImplSuite")
         val taskScheduler = new TaskSchedulerImpl(sc)
    
    From e9ca16ec943b9553056482d0c085eacb6046821e Mon Sep 17 00:00:00 2001
    From: Liang-Chi Hsieh 
    Date: Fri, 9 Jan 2015 10:27:33 -0800
    Subject: [PATCH 097/215] [SPARK-5145][Mllib] Add BLAS.dsyr and use it in
     GaussianMixtureEM
    
    This PR uses BLAS.dsyr to replace a few computations in GaussianMixtureEM.
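
    For context, dsyr performs the symmetric rank-1 update A := alpha * x * x^T + A. A standalone illustration of that update on a plain column-major array (independent of the MLlib internals touched here):

      // Rank-1 update A := alpha * x * x^T + A on a column-major n x n array.
      // The patch delegates the upper triangle to netlib's dsyr and mirrors it into the
      // lower triangle; this simplified sketch just updates every entry directly.
      def syrSketch(alpha: Double, x: Array[Double], a: Array[Double]): Unit = {
        val n = x.length
        require(a.length == n * n, "A must be n x n")
        var j = 0
        while (j < n) {
          var i = 0
          while (i < n) {
            a(j * n + i) += alpha * x(i) * x(j)
            i += 1
          }
          j += 1
        }
      }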
    
    Author: Liang-Chi Hsieh 
    
    Closes #3949 from viirya/blas_dsyr and squashes the following commits:
    
    4e4d6cf [Liang-Chi Hsieh] Add unit test. Rename function name, modify doc and style.
    3f57fd2 [Liang-Chi Hsieh] Add BLAS.dsyr and use it in GaussianMixtureEM.
    ---
     .../mllib/clustering/GaussianMixtureEM.scala  | 10 +++--
     .../org/apache/spark/mllib/linalg/BLAS.scala  | 26 ++++++++++++
     .../apache/spark/mllib/linalg/BLASSuite.scala | 41 +++++++++++++++++++
     3 files changed, 73 insertions(+), 4 deletions(-)
    
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
    index bdf984aee4da..3a6c0e681e3f 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
    @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq
     
     import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
    +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
     import org.apache.spark.mllib.stat.impl.MultivariateGaussian
     import org.apache.spark.mllib.util.MLUtils
     
    @@ -151,9 +151,10 @@ class GaussianMixtureEM private (
           var i = 0
           while (i < k) {
             val mu = sums.means(i) / sums.weights(i)
    -        val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
    +        BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
    +          Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
             weights(i) = sums.weights(i) / sumWeights
    -        gaussians(i) = new MultivariateGaussian(mu, sigma)
    +        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
             i = i + 1
           }
        
    @@ -211,7 +212,8 @@ private object ExpectationSum {
           p(i) /= pSum
           sums.weights(i) += p(i)
           sums.means(i) += x * p(i)
    -      sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
    +      BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
    +        Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
           i = i + 1
         }
         sums
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    index 9fed513becdd..3414daccd7ca 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
    @@ -228,6 +228,32 @@ private[spark] object BLAS extends Serializable with Logging {
         }
         _nativeBLAS
       }
    + 
    +  /**
    +   * A := alpha * x * x^T^ + A
    +   * @param alpha a real scalar that will be multiplied to x * x^T^.
    +   * @param x the vector x that contains the n elements.
    +   * @param A the symmetric matrix A. Size of n x n.
    +   */
    +  def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
    +    val mA = A.numRows
    +    val nA = A.numCols
    +    require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
    +    require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
    +
    +    nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
    +
    +    // Fill lower triangular part of A
    +    var i = 0
    +    while (i < mA) {
    +      var j = i + 1
    +      while (j < nA) {
    +        A(j, i) = A(i, j)
    +        j += 1
    +      }
    +      i += 1
    +    }    
    +  }
     
       /**
        * C := alpha * A * B + beta * C
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    index 5d70c914f14b..771878e925ea 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
    @@ -127,6 +127,47 @@ class BLASSuite extends FunSuite {
         }
       }
     
    +  test("syr") {
    +    val dA = new DenseMatrix(4, 4,
    +      Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
    +    val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1))
    +    val alpha = 0.15
    +
    +    val expected = new DenseMatrix(4, 4,
    +      Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1,
    +        5.4505, 4.1025, 1.4615))
    +
    +    syr(alpha, x, dA)
    +
    +    assert(dA ~== expected absTol 1e-15)
    + 
    +    val dB =
    +      new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
    +
    +    withClue("Matrix A must be a symmetric Matrix") {
    +      intercept[Exception] {
    +        syr(alpha, x, dB)
    +      }
    +    }
    + 
    +    val dC =
    +      new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
    +
    +    withClue("Size of vector must match the rank of matrix") {
    +      intercept[Exception] {
    +        syr(alpha, x, dC)
    +      }
    +    }
    + 
    +    val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
    +
    +    withClue("Size of vector must match the rank of matrix") {
    +      intercept[Exception] {
    +        syr(alpha, y, dA)
    +      }
    +    }
    +  }
    +
       test("gemm") {
     
         val dA =
    
    From 454fe129ee97b859bf079db8b9158e115a219ad5 Mon Sep 17 00:00:00 2001
    From: Jongyoul Lee 
    Date: Fri, 9 Jan 2015 10:47:08 -0800
    Subject: [PATCH 098/215] [SPARK-3619] Upgrade to Mesos 0.21 to work around
     MESOS-1688
    
    - update version from 0.18.1 to 0.21.0
    - I'm running some tests to verify that Spark jobs work fine in a Mesos 0.21.0 environment.
    
    Author: Jongyoul Lee 
    
    Closes #3934 from jongyoul/SPARK-3619 and squashes the following commits:
    
    ab994fa [Jongyoul Lee] [SPARK-3619] Upgrade to Mesos 0.21 to work around MESOS-1688 - update version from 0.18.1 to 0.21.0
    ---
     pom.xml | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/pom.xml b/pom.xml
    index 703e5c47bf59..aadcdfd1083c 100644
    --- a/pom.xml
    +++ b/pom.xml
    @@ -115,7 +115,7 @@
         1.6
         spark
         2.0.1
    -    0.18.1
    +    0.21.0
         shaded-protobuf
         1.7.5
         1.2.17
    
    From 7e8e62aec11c43c983055adc475b96006412199a Mon Sep 17 00:00:00 2001
    From: "Joseph K. Bradley" 
    Date: Fri, 9 Jan 2015 13:00:15 -0800
    Subject: [PATCH 099/215] [SPARK-5015] [mllib] Random seed for GMM + make test
     suite deterministic
    
    Issues:
    * From JIRA: GaussianMixtureEM uses randomness but does not take a random seed. It should take one as a parameter.
    * This also makes the test suite flaky since initialization can fail due to stochasticity.
    
    Fix:
    * Add random seed
    * Use it in test suite
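
    A usage sketch of the new seed parameter (method names as added by this patch; `data` is assumed to be an RDD[Vector]):

      // Fixing the seed makes initialization, and therefore the test suite, deterministic.
      val gmm = new GaussianMixtureEM()
        .setK(2)
        .setSeed(12345L)
        .run(data)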
    
    CC: mengxr  tgaloppo
    
    Author: Joseph K. Bradley 
    
    Closes #3981 from jkbradley/gmm-seed and squashes the following commits:
    
    f0df4fd [Joseph K. Bradley] Added seed parameter to GMM.  Updated test suite to use seed to prevent flakiness
    ---
     .../mllib/clustering/GaussianMixtureEM.scala  | 26 ++++++++++++++-----
     .../GMMExpectationMaximizationSuite.scala     | 14 +++++-----
     2 files changed, 27 insertions(+), 13 deletions(-)
    
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
    index 3a6c0e681e3f..b3c5631cc4cc 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
    @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
     import org.apache.spark.mllib.stat.impl.MultivariateGaussian
     import org.apache.spark.mllib.util.MLUtils
    +import org.apache.spark.util.Utils
     
     /**
      * This class performs expectation maximization for multivariate Gaussian
    @@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils
     class GaussianMixtureEM private (
         private var k: Int, 
         private var convergenceTol: Double, 
    -    private var maxIterations: Int) extends Serializable {
    +    private var maxIterations: Int,
    +    private var seed: Long) extends Serializable {
       
       /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
    -  def this() = this(2, 0.01, 100)
    +  def this() = this(2, 0.01, 100, Utils.random.nextLong())
       
       // number of samples per cluster to use when initializing Gaussians
       private val nSamples = 5
    @@ -100,11 +102,21 @@ class GaussianMixtureEM private (
         this
       }
       
    -  /** Return the largest change in log-likelihood at which convergence is
    -   *  considered to have occurred.
    +  /**
    +   * Return the largest change in log-likelihood at which convergence is
    +   * considered to have occurred.
        */
       def getConvergenceTol: Double = convergenceTol
    -  
    +
    +  /** Set the random seed */
    +  def setSeed(seed: Long): this.type = {
    +    this.seed = seed
    +    this
    +  }
    +
    +  /** Return the random seed */
    +  def getSeed: Long = seed
    +
       /** Perform expectation maximization */
       def run(data: RDD[Vector]): GaussianMixtureModel = {
         val sc = data.sparkContext
    @@ -113,7 +125,7 @@ class GaussianMixtureEM private (
         val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
         
         // Get length of the input vectors
    -    val d = breezeData.first.length 
    +    val d = breezeData.first().length
         
         // Determine initial weights and corresponding Gaussians.
         // If the user supplied an initial GMM, we use those values, otherwise
    @@ -126,7 +138,7 @@ class GaussianMixtureEM private (
           })
           
           case None => {
    -        val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
    +        val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
             (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => 
               val slice = samples.view(i * nSamples, (i + 1) * nSamples)
               new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) 
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
    index 23feb82874b7..9da5495741a8 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
    @@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
         val Ew = 1.0
         val Emu = Vectors.dense(5.0, 10.0)
         val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
    -    
    -    val gmm = new GaussianMixtureEM().setK(1).run(data)
    -                
    -    assert(gmm.weight(0) ~== Ew absTol 1E-5)
    -    assert(gmm.mu(0) ~== Emu absTol 1E-5)
    -    assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
    +
    +    val seeds = Array(314589, 29032897, 50181, 494821, 4660)
    +    seeds.foreach { seed =>
    +      val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
    +      assert(gmm.weight(0) ~== Ew absTol 1E-5)
    +      assert(gmm.mu(0) ~== Emu absTol 1E-5)
    +      assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
    +    }
       }
       
       test("two clusters") {
    
    From e96645206006a009e5c1a23bbd177dcaf3ef9b83 Mon Sep 17 00:00:00 2001
    From: WangTaoTheTonic 
    Date: Fri, 9 Jan 2015 13:20:32 -0800
    Subject: [PATCH 100/215] [SPARK-1953][YARN]yarn client mode Application Master
     memory size is same as driver memory...
    
    ... size
    
Ways to set the Application Master's memory in yarn-client mode:
    1.  `spark.yarn.am.memory` in SparkConf or System Properties
    2.  default value 512m
    
Note: this argument is only available in yarn-client mode.
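
As an aside for readers (not part of the change itself), a minimal yarn-client sketch using the new property might look like the snippet below; the app name and the 1g value are invented for illustration, while spark.yarn.am.memory and its 512m default come from this patch:

    import org.apache.spark.{SparkConf, SparkContext}

    object AmMemoryDemo {
      def main(args: Array[String]): Unit = {
        // Hypothetical client-mode app asking for 1g for the YARN Application Master.
        // In cluster mode this setting does not apply; spark.driver.memory is used instead.
        val conf = new SparkConf()
          .setAppName("am-memory-demo")        // example name only
          .setMaster("yarn-client")
          .set("spark.yarn.am.memory", "1g")   // new property; defaults to 512m
        val sc = new SparkContext(conf)
        println(sc.parallelize(1 to 10).sum())
        sc.stop()
      }
    }

The same property can also be passed on the command line, e.g. spark-submit --conf spark.yarn.am.memory=1g.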
    
    Author: WangTaoTheTonic 
    
    Closes #3607 from WangTaoTheTonic/SPARK4181 and squashes the following commits:
    
    d5ceb1b [WangTaoTheTonic] spark.driver.memeory is used in both modes
    6c1b264 [WangTaoTheTonic] rebase
    b8410c0 [WangTaoTheTonic] minor optiminzation
    ddcd592 [WangTaoTheTonic] fix the bug produced in rebase and some improvements
    3bf70cc [WangTaoTheTonic] rebase and give proper hint
    987b99d [WangTaoTheTonic] disable --driver-memory in client mode
    2b27928 [WangTaoTheTonic] inaccurate description
    b7acbb2 [WangTaoTheTonic] incorrect method invoked
    2557c5e [WangTaoTheTonic] missing a single blank
    42075b0 [WangTaoTheTonic] arrange the args and warn logging
    69c7dba [WangTaoTheTonic] rebase
    1960d16 [WangTaoTheTonic] fix wrong comment
    7fa9e2e [WangTaoTheTonic] log a warning
    f6bee0e [WangTaoTheTonic] docs issue
    d619996 [WangTaoTheTonic] Merge branch 'master' into SPARK4181
    b09c309 [WangTaoTheTonic] use code format
    ab16bb5 [WangTaoTheTonic] fix bug and add comments
    44e48c2 [WangTaoTheTonic] minor fix
    6fd13e1 [WangTaoTheTonic] add overhead mem and remove some configs
    0566bb8 [WangTaoTheTonic] yarn client mode Application Master memory size is same as driver memory size
    ---
     .../spark/deploy/SparkSubmitArguments.scala   |  3 +-
     docs/running-on-yarn.md                       | 19 +++++++++-
     .../org/apache/spark/deploy/yarn/Client.scala |  2 +-
     .../spark/deploy/yarn/ClientArguments.scala   | 37 ++++++++++++++-----
     .../cluster/YarnClientSchedulerBackend.scala  |  2 -
     5 files changed, 48 insertions(+), 15 deletions(-)
    
    diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
    index 1faabe91f49a..f14ef4d29938 100644
    --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
    +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
    @@ -405,7 +405,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
             |  --queue QUEUE_NAME          The YARN queue to submit to (Default: "default").
             |  --num-executors NUM         Number of executors to launch (Default: 2).
             |  --archives ARCHIVES         Comma separated list of archives to be extracted into the
    -        |                              working directory of each executor.""".stripMargin
    +        |                              working directory of each executor.
    +      """.stripMargin
         )
         SparkSubmit.exitFn()
       }
    diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
    index 183698ffe930..4f273098c5db 100644
    --- a/docs/running-on-yarn.md
    +++ b/docs/running-on-yarn.md
    @@ -21,6 +21,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
     
     
    +
    +  
    +  
    +  
    +
    @@ -90,7 +98,14 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
       
    +
    +
    +  
    +  
    +  
    @@ -145,7 +160,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
       
    diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    index c363d755c175..032106371cd6 100644
    --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    @@ -65,7 +65,7 @@ private[spark] class Client(
       private val amMemoryOverhead = args.amMemoryOverhead // MB
       private val executorMemoryOverhead = args.executorMemoryOverhead // MB
       private val distCacheMgr = new ClientDistributedCacheManager()
    -  private val isClusterMode = args.userClass != null
    +  private val isClusterMode = args.isClusterMode
     
     
       def stop(): Unit = yarnClient.stop()
    diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
    index 39f1021c9d94..fdbf9f8eed02 100644
    --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
    +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
    @@ -38,23 +38,27 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
       var amMemory: Int = 512 // MB
       var appName: String = "Spark"
       var priority = 0
    +  def isClusterMode: Boolean = userClass != null
    +
    +  private var driverMemory: Int = 512 // MB
    +  private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead"
    +  private val amMemKey = "spark.yarn.am.memory"
    +  private val amMemOverheadKey = "spark.yarn.am.memoryOverhead"
    +  private val isDynamicAllocationEnabled =
    +    sparkConf.getBoolean("spark.dynamicAllocation.enabled", false)
     
       parseArgs(args.toList)
    +  loadEnvironmentArgs()
    +  validateArgs()
     
       // Additional memory to allocate to containers
    -  // For now, use driver's memory overhead as our AM container's memory overhead
    -  val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead",
    +  val amMemoryOverheadConf = if (isClusterMode) driverMemOverheadKey else amMemOverheadKey
    +  val amMemoryOverhead = sparkConf.getInt(amMemoryOverheadConf,
         math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN))
     
       val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead",
         math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN))
     
    -  private val isDynamicAllocationEnabled =
    -    sparkConf.getBoolean("spark.dynamicAllocation.enabled", false)
    -
    -  loadEnvironmentArgs()
    -  validateArgs()
    -
       /** Load any default arguments provided through environment variables and Spark properties. */
       private def loadEnvironmentArgs(): Unit = {
         // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://,
    @@ -87,6 +91,21 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
           throw new IllegalArgumentException(
             "You must specify at least 1 executor!\n" + getUsageMessage())
         }
    +    if (isClusterMode) {
    +      for (key <- Seq(amMemKey, amMemOverheadKey)) {
    +        if (sparkConf.contains(key)) {
    +          println(s"$key is set but does not apply in cluster mode.")
    +        }
    +      }
    +      amMemory = driverMemory
    +    } else {
    +      if (sparkConf.contains(driverMemOverheadKey)) {
    +        println(s"$driverMemOverheadKey is set but does not apply in client mode.")
    +      }
    +      sparkConf.getOption(amMemKey)
    +        .map(Utils.memoryStringToMb)
    +        .foreach { mem => amMemory = mem }
    +    }
       }
     
       private def parseArgs(inputArgs: List[String]): Unit = {
    @@ -118,7 +137,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
               if (args(0) == "--master-memory") {
                 println("--master-memory is deprecated. Use --driver-memory instead.")
               }
    -          amMemory = value
    +          driverMemory = value
               args = tail
     
             case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail =>
    diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
    index 09597bd0e6ab..f99291553b7b 100644
    --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
    +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
    @@ -68,8 +68,6 @@ private[spark] class YarnClientSchedulerBackend(
         // List of (target Client argument, environment variable, Spark property)
         val optionTuples =
           List(
    -        ("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"),
    -        ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"),
             ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"),
             ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"),
             ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"),
    
    From e0f28e010cdd67a2a4c8aebd35323d69a3182ba8 Mon Sep 17 00:00:00 2001
    From: mcheah 
    Date: Fri, 9 Jan 2015 14:16:20 -0800
    Subject: [PATCH 101/215] [SPARK-4737] Task set manager properly handles
     serialization errors
    
As discussed in [SPARK-4737], handling serialization errors should not be the DAGScheduler's responsibility. The TaskSetManager now catches the error and aborts the stage.
    
    If the TaskSetManager throws a TaskNotSerializableException, the TaskSchedulerImpl will return an empty list of task descriptions, because no tasks were started. The scheduler should abort the stage gracefully.
    
Note that I'm not too familiar with this part of the codebase and its place in the overall architecture of the Spark stack. If implementing it this way will have any adverse side effects, please voice that loudly.
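
For context, a rough, self-contained sketch of the failure mode this guards against is below; it mirrors the new RDDSuite test rather than adding any API, and the class name, app name, and values are invented:

    import java.io.{IOException, ObjectOutputStream}

    import org.apache.spark.{SparkConf, SparkContext}

    // An element whose Java serialization always fails, like the test's BadSerializable.
    class Unserializable extends Serializable {
      @throws(classOf[IOException])
      private def writeObject(out: ObjectOutputStream): Unit =
        throw new IOException("cannot serialize")
    }

    object SerializationFailureDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("demo"))
        try {
          // With this patch the TaskSetManager aborts the stage, so collect() throws
          // instead of the scheduler hanging.
          sc.parallelize(Seq(new Unserializable, new Unserializable)).collect()
        } catch {
          case e: Exception => println(s"job failed fast: ${e.getMessage}")
        }
        // The SparkContext is still usable for later jobs.
        println(sc.parallelize(1 to 10).map(_ * 2).collect().mkString(","))
        sc.stop()
      }
    }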
    
    Author: mcheah 
    
    Closes #3638 from mccheah/task-set-manager-properly-handle-ser-err and squashes the following commits:
    
    1545984 [mcheah] Some more style fixes from Andrew Or.
    5267929 [mcheah] Fixing style suggestions from Andrew Or.
    dfa145b [mcheah] Fixing style from Josh Rosen's feedback
    b2a430d [mcheah] Not returning empty seq when a task set cannot be serialized.
    94844d7 [mcheah] Fixing compilation error, one brace too many
    5f486f4 [mcheah] Adding license header for fake task class
    bf5e706 [mcheah] Fixing indentation.
    097e7a2 [mcheah] [SPARK-4737] Catching task serialization exception in TaskSetManager
    ---
     .../spark/TaskNotSerializableException.scala  | 25 +++++++++
     .../apache/spark/scheduler/DAGScheduler.scala | 20 -------
     .../spark/scheduler/TaskSchedulerImpl.scala   | 54 +++++++++++++------
     .../spark/scheduler/TaskSetManager.scala      | 18 +++++--
     .../org/apache/spark/SharedSparkContext.scala |  2 +-
     .../scala/org/apache/spark/rdd/RDDSuite.scala | 21 ++++++++
     .../scheduler/NotSerializableFakeTask.scala   | 40 ++++++++++++++
     .../scheduler/TaskSchedulerImplSuite.scala    | 30 +++++++++++
     .../spark/scheduler/TaskSetManagerSuite.scala | 14 +++++
     9 files changed, 182 insertions(+), 42 deletions(-)
     create mode 100644 core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
     create mode 100644 core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
    
    diff --git a/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
    new file mode 100644
    index 000000000000..9df61062e1f8
    --- /dev/null
    +++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
    @@ -0,0 +1,25 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark
    +
    +import org.apache.spark.annotation.DeveloperApi
    +
    +/**
    + * Exception thrown when a task cannot be serialized.
    + */
    +private[spark] class TaskNotSerializableException(error: Throwable) extends Exception(error)
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    index 259621d263d7..61d09d73e17c 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    @@ -866,26 +866,6 @@ class DAGScheduler(
         }
     
         if (tasks.size > 0) {
    -      // Preemptively serialize a task to make sure it can be serialized. We are catching this
    -      // exception here because it would be fairly hard to catch the non-serializable exception
    -      // down the road, where we have several different implementations for local scheduler and
    -      // cluster schedulers.
    -      //
    -      // We've already serialized RDDs and closures in taskBinary, but here we check for all other
    -      // objects such as Partition.
    -      try {
    -        closureSerializer.serialize(tasks.head)
    -      } catch {
    -        case e: NotSerializableException =>
    -          abortStage(stage, "Task not serializable: " + e.toString)
    -          runningStages -= stage
    -          return
    -        case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo.
    -          abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
    -          runningStages -= stage
    -          return
    -      }
    -
           logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
           stage.pendingTasks ++= tasks
           logDebug("New pending tasks: " + stage.pendingTasks)
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
    index a41f3eef195d..a1dfb0106259 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
    +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
    @@ -31,6 +31,7 @@ import scala.util.Random
     import org.apache.spark._
     import org.apache.spark.TaskState.TaskState
     import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
    +import org.apache.spark.scheduler.TaskLocality.TaskLocality
     import org.apache.spark.util.Utils
     import org.apache.spark.executor.TaskMetrics
     import org.apache.spark.storage.BlockManagerId
    @@ -209,6 +210,40 @@ private[spark] class TaskSchedulerImpl(
           .format(manager.taskSet.id, manager.parent.name))
       }
     
    +  private def resourceOfferSingleTaskSet(
    +      taskSet: TaskSetManager,
    +      maxLocality: TaskLocality,
    +      shuffledOffers: Seq[WorkerOffer],
    +      availableCpus: Array[Int],
    +      tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
    +    var launchedTask = false
    +    for (i <- 0 until shuffledOffers.size) {
    +      val execId = shuffledOffers(i).executorId
    +      val host = shuffledOffers(i).host
    +      if (availableCpus(i) >= CPUS_PER_TASK) {
    +        try {
    +          for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
    +            tasks(i) += task
    +            val tid = task.taskId
    +            taskIdToTaskSetId(tid) = taskSet.taskSet.id
    +            taskIdToExecutorId(tid) = execId
    +            executorsByHost(host) += execId
    +            availableCpus(i) -= CPUS_PER_TASK
    +            assert(availableCpus(i) >= 0)
    +            launchedTask = true
    +          }
    +        } catch {
    +          case e: TaskNotSerializableException =>
    +            logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
    +            // Do not offer resources for this task, but don't throw an error to allow other
    +            // task sets to be submitted.
    +            return launchedTask
    +        }
    +      }
    +    }
    +    return launchedTask
    +  }
    +
       /**
        * Called by cluster manager to offer resources on slaves. We respond by asking our active task
        * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
    @@ -251,23 +286,8 @@ private[spark] class TaskSchedulerImpl(
         var launchedTask = false
         for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
           do {
    -        launchedTask = false
    -        for (i <- 0 until shuffledOffers.size) {
    -          val execId = shuffledOffers(i).executorId
    -          val host = shuffledOffers(i).host
    -          if (availableCpus(i) >= CPUS_PER_TASK) {
    -            for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
    -              tasks(i) += task
    -              val tid = task.taskId
    -              taskIdToTaskSetId(tid) = taskSet.taskSet.id
    -              taskIdToExecutorId(tid) = execId
    -              executorsByHost(host) += execId
    -              availableCpus(i) -= CPUS_PER_TASK
    -              assert(availableCpus(i) >= 0)
    -              launchedTask = true
    -            }
    -          }
    -        }
    +        launchedTask = resourceOfferSingleTaskSet(
    +            taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
           } while (launchedTask)
         }
     
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
    index 28e6147509f7..466785091715 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
    +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
    @@ -18,12 +18,14 @@
     package org.apache.spark.scheduler
     
     import java.io.NotSerializableException
    +import java.nio.ByteBuffer
     import java.util.Arrays
     
     import scala.collection.mutable.ArrayBuffer
     import scala.collection.mutable.HashMap
     import scala.collection.mutable.HashSet
     import scala.math.{min, max}
    +import scala.util.control.NonFatal
     
     import org.apache.spark._
     import org.apache.spark.executor.TaskMetrics
    @@ -417,6 +419,7 @@ private[spark] class TaskSetManager(
        * @param host  the host Id of the offered resource
        * @param maxLocality the maximum locality we want to schedule the tasks at
        */
    +  @throws[TaskNotSerializableException]
       def resourceOffer(
           execId: String,
           host: String,
    @@ -456,10 +459,17 @@ private[spark] class TaskSetManager(
               }
               // Serialize and return the task
               val startTime = clock.getTime()
    -          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
    -          // we assume the task can be serialized without exceptions.
    -          val serializedTask = Task.serializeWithDependencies(
    -            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
    +          val serializedTask: ByteBuffer = try {
    +            Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
    +          } catch {
    +            // If the task cannot be serialized, then there's no point to re-attempt the task,
    +            // as it will always fail. So just abort the whole task-set.
    +            case NonFatal(e) =>
    +              val msg = s"Failed to serialize task $taskId, not attempting to retry it."
    +              logError(msg, e)
    +              abort(s"$msg Exception during serialization: $e")
    +              throw new TaskNotSerializableException(e)
    +          }
               if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
                   !emittedTaskSizeWarning) {
                 emittedTaskSizeWarning = true
    diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
    index 0b6511a80df1..3d2700b7e6be 100644
    --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
    +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
    @@ -30,7 +30,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
       var conf = new SparkConf(false)
     
       override def beforeAll() {
    -    _sc = new SparkContext("local", "test", conf)
    +    _sc = new SparkContext("local[4]", "test", conf)
         super.beforeAll()
       }
     
    diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
    index 6836e9ab0fd6..0deb9b18b868 100644
    --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
    @@ -17,6 +17,10 @@
     
     package org.apache.spark.rdd
     
    +import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
    +
    +import com.esotericsoftware.kryo.KryoException
    +
     import scala.collection.mutable.{ArrayBuffer, HashMap}
     import scala.collection.JavaConverters._
     import scala.reflect.ClassTag
    @@ -887,6 +891,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
         assert(ancestors6.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 3)
       }
     
    +  test("task serialization exception should not hang scheduler") {
    +    class BadSerializable extends Serializable {
    +      @throws(classOf[IOException])
    +      private def writeObject(out: ObjectOutputStream): Unit = throw new KryoException("Bad serialization")
    +
    +      @throws(classOf[IOException])
    +      private def readObject(in: ObjectInputStream): Unit = {}
    +    }
    +    // Note that in the original bug, SPARK-4349, that this verifies, the job would only hang if there were
    +    // more threads in the Spark Context than there were number of objects in this sequence.
    +    intercept[Throwable] {
    +      sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect
    +    }
    +    // Check that the context has not crashed
    +    sc.parallelize(1 to 100).map(x => x*2).collect
    +  }
    +
       /** A contrived RDD that allows the manual addition of dependencies after creation. */
       private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) {
         private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty
    diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
    new file mode 100644
    index 000000000000..6b75c98839e0
    --- /dev/null
    +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
    @@ -0,0 +1,40 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.scheduler
    +
    +import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
    +
    +import org.apache.spark.TaskContext
    +
    +/**
    + * A Task implementation that fails to serialize.
    + */
    +private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends Task[Array[Byte]](stageId, 0) {
    +  override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
    +  override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
    +
    +  @throws(classOf[IOException])
    +  private def writeObject(out: ObjectOutputStream): Unit = {
    +    if (stageId == 0) {
    +      throw new IllegalStateException("Cannot serialize")
    +    }
    +  }
    +
    +  @throws(classOf[IOException])
    +  private def readObject(in: ObjectInputStream): Unit = {}
    +}
    diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
    index 8874cf00e999..add13f5b2176 100644
    --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
    @@ -100,4 +100,34 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin
         assert(1 === taskDescriptions.length)
         assert("executor0" === taskDescriptions(0).executorId)
       }
    +
    +  test("Scheduler does not crash when tasks are not serializable") {
    +    sc = new SparkContext("local", "TaskSchedulerImplSuite")
    +    val taskCpus = 2
    +
    +    sc.conf.set("spark.task.cpus", taskCpus.toString)
    +    val taskScheduler = new TaskSchedulerImpl(sc)
    +    taskScheduler.initialize(new FakeSchedulerBackend)
    +    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
    +    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
    +      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
    +      override def executorAdded(execId: String, host: String) {}
    +    }
    +    val numFreeCores = 1
    +    taskScheduler.setDAGScheduler(dagScheduler)
    +    var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
    +    val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", taskCpus),
    +      new WorkerOffer("executor1", "host1", numFreeCores))
    +    taskScheduler.submitTasks(taskSet)
    +    var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
    +    assert(0 === taskDescriptions.length)
    +
    +    // Now check that we can still submit tasks
    +    // Even if one of the tasks has not-serializable tasks, the other task set should still be processed without error
    +    taskScheduler.submitTasks(taskSet)
    +    taskScheduler.submitTasks(FakeTask.createTaskSet(1))
    +    taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
    +    assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
    +  }
    +
     }
    diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
    index 472191551a01..84b9b788237b 100644
    --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
    @@ -17,6 +17,7 @@
     
     package org.apache.spark.scheduler
     
    +import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
     import java.util.Random
     
     import scala.collection.mutable.ArrayBuffer
    @@ -563,6 +564,19 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
         assert(manager.emittedTaskSizeWarning)
       }
     
    +  test("Not serializable exception thrown if the task cannot be serialized") {
    +    sc = new SparkContext("local", "test")
    +    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
    +
    +    val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
    +    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
    +
    +    intercept[TaskNotSerializableException] {
    +      manager.resourceOffer("exec1", "host1", ANY)
    +    }
    +    assert(manager.isZombie)
    +  }
    +
       test("abort the job if total size of results is too large") {
         val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
         sc = new SparkContext("local", "test", conf)
    
    From ae628725abce9ffe34b9a7110d5ac51a076454aa Mon Sep 17 00:00:00 2001
    From: Kousuke Saruta 
    Date: Fri, 9 Jan 2015 14:40:45 -0800
    Subject: [PATCH 102/215] [DOC] Fixed Mesos version in doc from 0.18.1 to
     0.21.0
    
#3934 upgraded the Mesos version, so we should also fix the docs, right?

This issue is really minor, so I didn't file a JIRA for it.
    
    Author: Kousuke Saruta 
    
    Closes #3982 from sarutak/fix-mesos-version and squashes the following commits:
    
    9a86ee3 [Kousuke Saruta] Fixed mesos version from 0.18.1 to 0.21.0
    ---
     docs/_config.yml | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/docs/_config.yml b/docs/_config.yml
    index a96a76dd9ab5..e2db274e1f61 100644
    --- a/docs/_config.yml
    +++ b/docs/_config.yml
    @@ -17,6 +17,6 @@ SPARK_VERSION: 1.3.0-SNAPSHOT
     SPARK_VERSION_SHORT: 1.3.0
     SCALA_BINARY_VERSION: "2.10"
     SCALA_VERSION: "2.10.4"
    -MESOS_VERSION: 0.18.1
    +MESOS_VERSION: 0.21.0
     SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
     SPARK_GITHUB_URL: https://github.com/apache/spark
    
    From 4e1f12d997426560226648d62ee17c90352613e7 Mon Sep 17 00:00:00 2001
    From: bilna 
    Date: Fri, 9 Jan 2015 14:45:28 -0800
    Subject: [PATCH 103/215] [Minor] Fix import order and other coding style
    
Fixed the import order and other coding style issues.
    
    Author: bilna 
    Author: Bilna P 
    
    Closes #3966 from Bilna/master and squashes the following commits:
    
    5e76f04 [bilna] fix import order and other coding style
    5718d66 [bilna] Merge remote-tracking branch 'upstream/master'
    ae56514 [bilna] Merge remote-tracking branch 'upstream/master'
    acea3a3 [bilna] Adding dependency with scope test
    28681fa [bilna] Merge remote-tracking branch 'upstream/master'
    fac3904 [bilna] Correction in Indentation and coding style
    ed9db4c [bilna] Merge remote-tracking branch 'upstream/master'
    4b34ee7 [Bilna P] Update MQTTStreamSuite.scala
    04503cf [bilna] Added embedded broker service for mqtt test
    89d804e [bilna] Merge remote-tracking branch 'upstream/master'
    fc8eb28 [bilna] Merge remote-tracking branch 'upstream/master'
    4b58094 [Bilna P] Update MQTTStreamSuite.scala
    b1ac4ad [bilna] Added BeforeAndAfter
    5f6bfd2 [bilna] Added BeforeAndAfter
    e8b6623 [Bilna P] Update MQTTStreamSuite.scala
    5ca6691 [Bilna P] Update MQTTStreamSuite.scala
    8616495 [bilna] [SPARK-4631] unit test for MQTT
    ---
     .../spark/streaming/mqtt/MQTTStreamSuite.scala  | 17 +++++++++++------
     1 file changed, 11 insertions(+), 6 deletions(-)
    
    diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
    index 98fe6cb301f5..39eb8b183488 100644
    --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
    +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
    @@ -19,16 +19,19 @@ package org.apache.spark.streaming.mqtt
     
     import java.net.{URI, ServerSocket}
     
    +import scala.concurrent.duration._
    +
     import org.apache.activemq.broker.{TransportConnector, BrokerService}
    -import org.apache.spark.util.Utils
    +import org.eclipse.paho.client.mqttv3._
    +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
    +
     import org.scalatest.{BeforeAndAfter, FunSuite}
     import org.scalatest.concurrent.Eventually
    -import scala.concurrent.duration._
    +
     import org.apache.spark.streaming.{Milliseconds, StreamingContext}
     import org.apache.spark.storage.StorageLevel
     import org.apache.spark.streaming.dstream.ReceiverInputDStream
    -import org.eclipse.paho.client.mqttv3._
    -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
    +import org.apache.spark.util.Utils
     
     class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
     
    @@ -38,8 +41,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
       private val freePort = findFreePort()
       private val brokerUri = "//localhost:" + freePort
       private val topic = "def"
    -  private var ssc: StreamingContext = _
       private val persistenceDir = Utils.createTempDir()
    +
    +  private var ssc: StreamingContext = _
       private var broker: BrokerService = _
       private var connector: TransportConnector = _
     
    @@ -115,8 +119,9 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
             val message: MqttMessage = new MqttMessage(data.getBytes("utf-8"))
             message.setQos(1)
             message.setRetained(true)
    -        for (i <- 0 to 100)
    +        for (i <- 0 to 100) {
               msgTopic.publish(message)
    +        }
           }
         } finally {
           client.disconnect()
    
    From 8782eb992f461502238c41ece3a3002efa67a792 Mon Sep 17 00:00:00 2001
    From: WangTaoTheTonic 
    Date: Fri, 9 Jan 2015 17:10:02 -0800
    Subject: [PATCH 104/215] [SPARK-4990][Deploy]to find default properties file,
     search SPARK_CONF_DIR first
    
    https://issues.apache.org/jira/browse/SPARK-4990
    
    Author: WangTaoTheTonic 
    Author: WangTao 
    
    Closes #3823 from WangTaoTheTonic/SPARK-4990 and squashes the following commits:
    
    133c43e [WangTao] Update spark-submit2.cmd
    b1ab402 [WangTao] Update spark-submit
    4cc7f34 [WangTaoTheTonic] rebase
    55300bc [WangTaoTheTonic] use export to make it global
    d8d3cb7 [WangTaoTheTonic] remove blank line
    07b9ebf [WangTaoTheTonic] check SPARK_CONF_DIR instead of checking properties file
    c5a85eb [WangTaoTheTonic] to find default properties file, search SPARK_CONF_DIR first
    ---
     bin/spark-submit      | 5 ++++-
     bin/spark-submit2.cmd | 6 +++++-
     2 files changed, 9 insertions(+), 2 deletions(-)
    
    diff --git a/bin/spark-submit b/bin/spark-submit
    index aefd38a0a2b9..3e5cbdbb2439 100755
    --- a/bin/spark-submit
    +++ b/bin/spark-submit
    @@ -44,7 +44,10 @@ while (($#)); do
       shift
     done
     
    -DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf"
    +if [ -z "$SPARK_CONF_DIR" ]; then
    +  export SPARK_CONF_DIR="$SPARK_HOME/conf"
    +fi
    +DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf"
     if [ "$MASTER" == "yarn-cluster" ]; then
       SPARK_SUBMIT_DEPLOY_MODE=cluster
     fi
    diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
    index daf0284db923..12244a9cb04f 100644
    --- a/bin/spark-submit2.cmd
    +++ b/bin/spark-submit2.cmd
    @@ -24,7 +24,11 @@ set ORIG_ARGS=%*
     
     rem Reset the values of all variables used
     set SPARK_SUBMIT_DEPLOY_MODE=client
    -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
    +
    +if not defined %SPARK_CONF_DIR% (
    +  set SPARK_CONF_DIR=%SPARK_HOME%\conf
    +)
    +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf
     set SPARK_SUBMIT_DRIVER_MEMORY=
     set SPARK_SUBMIT_LIBRARY_PATH=
     set SPARK_SUBMIT_CLASSPATH=
    
    From 4554529dce8fe8ca937d887109ef072eef52bf51 Mon Sep 17 00:00:00 2001
    From: MechCoder 
    Date: Fri, 9 Jan 2015 17:45:18 -0800
    Subject: [PATCH 105/215] [SPARK-4406] [MLib] FIX: Validate k in SVD
    
Raise an exception when k is non-positive (or larger than the number of columns) in SVD.
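
As a hedged usage sketch (not part of the patch), the new check surfaces to callers roughly like this; the tiny matrix and app name below are invented for illustration:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.linalg.distributed.RowMatrix

    object SvdKValidationDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("svd-k"))
        val rows = sc.parallelize(Seq(
          Vectors.dense(1.0, 2.0, 3.0),
          Vectors.dense(4.0, 5.0, 6.0)))
        val mat = new RowMatrix(rows)

        // Valid request: 0 < k <= numCols.
        val svd = mat.computeSVD(2, computeU = true)
        println(svd.s)

        // Invalid request: with this patch it fails fast with an IllegalArgumentException
        // from the new require(...) instead of failing later in the computation.
        try {
          mat.computeSVD(-1)
        } catch {
          case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
        }
        sc.stop()
      }
    }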
    
    Author: MechCoder 
    
    Closes #3945 from MechCoder/spark-4406 and squashes the following commits:
    
    64e6d2d [MechCoder] TST: Add better test errors and messages
    12dae73 [MechCoder] [SPARK-4406] FIX: Validate k in SVD
    ---
     .../spark/mllib/linalg/distributed/IndexedRowMatrix.scala | 3 +++
     .../apache/spark/mllib/linalg/distributed/RowMatrix.scala | 2 +-
     .../mllib/linalg/distributed/IndexedRowMatrixSuite.scala  | 7 +++++++
     .../spark/mllib/linalg/distributed/RowMatrixSuite.scala   | 8 ++++++++
     4 files changed, 19 insertions(+), 1 deletion(-)
    
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
    index 36d8cadd2bdd..181f50751648 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
    @@ -102,6 +102,9 @@ class IndexedRowMatrix(
           k: Int,
           computeU: Boolean = false,
           rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = {
    +
    +    val n = numCols().toInt
    +    require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
         val indices = rows.map(_.index)
         val svd = toRowMatrix().computeSVD(k, computeU, rCond)
         val U = if (computeU) {
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    index fbd35e372f9b..d5abba6a4b64 100644
    --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
    @@ -212,7 +212,7 @@ class RowMatrix(
           tol: Double,
           mode: String): SingularValueDecomposition[RowMatrix, Matrix] = {
         val n = numCols().toInt
    -    require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.")
    +    require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.")
     
         object SVDMode extends Enumeration {
           val LocalARPACK, LocalLAPACK, DistARPACK = Value
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
    index e25bc02b06c9..741cd4997b85 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
    @@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
         assert(closeToZero(U * brzDiag(s) * V.t - localA))
       }
     
    +  test("validate k in svd") {
    +    val A = new IndexedRowMatrix(indexedRows)
    +    intercept[IllegalArgumentException] {
    +      A.computeSVD(-1)
    +    }
    +  }
    +
       def closeToZero(G: BDM[Double]): Boolean = {
         G.valuesIterator.map(math.abs).sum < 1e-6
       }
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    index dbf55ff81ca9..3309713e91f8 100644
    --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
    @@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
         }
       }
     
    +  test("validate k in svd") {
    +    for (mat <- Seq(denseMat, sparseMat)) {
    +      intercept[IllegalArgumentException] {
    +        mat.computeSVD(-1)
    +      }
    +    }
    +  }
    +
       def closeToZero(G: BDM[Double]): Boolean = {
         G.valuesIterator.map(math.abs).sum < 1e-6
       }
    
    From 545dfcb92f2ff82a51877d35d9669094ea81f466 Mon Sep 17 00:00:00 2001
    From: luogankun 
    Date: Fri, 9 Jan 2015 20:38:41 -0800
    Subject: [PATCH 106/215] [SPARK-5141][SQL]CaseInsensitiveMap throws
     java.io.NotSerializableException
    
    CaseInsensitiveMap throws java.io.NotSerializableException.
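
To illustrate why the mixin matters, here is a simplified stand-in (not the actual CaseInsensitiveMap) that survives a Java-serialization round trip, the same mechanism Spark uses when shipping plans and tasks:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    // Simplified stand-in: a case-insensitive lookup wrapper that mixes in Serializable.
    class LowerCaseKeyMap(map: Map[String, String]) extends Serializable {
      private val baseMap = map.map { case (k, v) => k.toLowerCase -> v }
      def get(k: String): Option[String] = baseMap.get(k.toLowerCase)
    }

    object CaseInsensitiveMapDemo {
      def main(args: Array[String]): Unit = {
        val m = new LowerCaseKeyMap(Map("Path" -> "/tmp/data"))

        // Round-trip through Java serialization; without `extends Serializable` the
        // writeObject call below is where java.io.NotSerializableException appears.
        val buffer = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(buffer)
        out.writeObject(m)
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
        val copy = in.readObject().asInstanceOf[LowerCaseKeyMap]
        println(copy.get("PATH"))   // Some(/tmp/data)
      }
    }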
    
    Author: luogankun 
    
    Closes #3944 from luogankun/SPARK-5141 and squashes the following commits:
    
    b6d63d5 [luogankun] [SPARK-5141]CaseInsensitiveMap throws java.io.NotSerializableException
    ---
     sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala | 3 ++-
     1 file changed, 2 insertions(+), 1 deletion(-)
    
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
    index 8a66ac31f2df..364bacec83b9 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
    @@ -110,7 +110,8 @@ private[sql] case class CreateTableUsing(
     /**
      * Builds a map in which keys are case insensitive
      */
    -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] {
    +protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] 
    +  with Serializable {
     
       val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
     
    
    From 1e56eba5d906bef793dfd6f199db735a6116a764 Mon Sep 17 00:00:00 2001
    From: Alex Liu 
    Date: Sat, 10 Jan 2015 13:19:12 -0800
    Subject: [PATCH 107/215] [SPARK-4925][SQL] Publish Spark SQL hive-thriftserver
     maven artifact
    
    Author: Alex Liu 
    
    Closes #3766 from alexliu68/SPARK-SQL-4925 and squashes the following commits:
    
    3137b51 [Alex Liu] [SPARK-4925][SQL] Remove sql/hive-thriftserver module from pom.xml
    15f2e38 [Alex Liu] [SPARK-4925][SQL] Publish Spark SQL hive-thriftserver maven artifact
    ---
     sql/hive-thriftserver/pom.xml | 7 -------
     1 file changed, 7 deletions(-)
    
    diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
    index 259eef0b80d0..123a1f629ab1 100644
    --- a/sql/hive-thriftserver/pom.xml
    +++ b/sql/hive-thriftserver/pom.xml
    @@ -76,13 +76,6 @@
               
             
           
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-deploy-plugin</artifactId>
-        <configuration>
-          <skip>true</skip>
-        </configuration>
-      </plugin>
         
       
     
    
    From 4b39fd1e63188821fc84a13f7ccb6e94277f4be7 Mon Sep 17 00:00:00 2001
    From: Alex Liu 
    Date: Sat, 10 Jan 2015 13:23:09 -0800
    Subject: [PATCH 108/215] [SPARK-4943][SQL] Allow table name having dot for
     db/catalog
    
This pull request only fixes the parsing error and changes the API to use tableIdentifier. Changes related to joining data sources from different catalogs are not included in this pull request.
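
For reference, the identifier handling added to Catalog.scala boils down to the helpers below; they copy the patch's logic, while the demo object, main method, and table names are added only for illustration:

    // Mirrors processTableIdentifier / getDbTableName / getDBTable from this patch.
    object TableIdentifierDemo {
      def processTableIdentifier(tableIdentifier: Seq[String], caseSensitive: Boolean): Seq[String] =
        if (!caseSensitive) tableIdentifier.map(_.toLowerCase) else tableIdentifier

      def getDbTableName(tableIdent: Seq[String]): String = {
        val size = tableIdent.size
        if (size <= 2) tableIdent.mkString(".") else tableIdent.slice(size - 2, size).mkString(".")
      }

      def getDBTable(tableIdent: Seq[String]): (Option[String], String) =
        (tableIdent.lift(tableIdent.size - 2), tableIdent.last)

      def main(args: Array[String]): Unit = {
        // A dotted name such as "mydb.MyTable" is parsed into a Seq via rep1sep(ident, ".").
        val ident = processTableIdentifier(Seq("mydb", "MyTable"), caseSensitive = false)
        println(getDbTableName(ident))   // mydb.mytable
        println(getDBTable(ident))       // (Some(mydb),mytable)
      }
    }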
    
    Author: Alex Liu 
    
    Closes #3941 from alexliu68/SPARK-SQL-4943-3 and squashes the following commits:
    
    343ae27 [Alex Liu] [SPARK-4943][SQL] refactoring according to review
    29e5e55 [Alex Liu] [SPARK-4943][SQL] fix failed Hive CTAS tests
    6ae77ce [Alex Liu] [SPARK-4943][SQL] fix TestHive matching error
    3652997 [Alex Liu] [SPARK-4943][SQL] Allow table name having dot to support db/catalog ...
    ---
     .../apache/spark/sql/catalyst/SqlParser.scala |   6 +-
     .../sql/catalyst/analysis/Analyzer.scala      |   8 +-
     .../spark/sql/catalyst/analysis/Catalog.scala | 106 +++++++++---------
     .../sql/catalyst/analysis/unresolved.scala    |   3 +-
     .../spark/sql/catalyst/dsl/package.scala      |   2 +-
     .../sql/catalyst/analysis/AnalysisSuite.scala |  20 ++--
     .../analysis/DecimalPrecisionSuite.scala      |   2 +-
     .../org/apache/spark/sql/SQLContext.scala     |   6 +-
     .../org/apache/spark/sql/SchemaRDDLike.scala  |   4 +-
     .../org/apache/spark/sql/JoinSuite.scala      |   4 +-
     .../apache/spark/sql/hive/HiveContext.scala   |   2 +-
     .../spark/sql/hive/HiveMetastoreCatalog.scala |  52 ++++++---
     .../org/apache/spark/sql/hive/HiveQl.scala    |  29 +++--
     .../org/apache/spark/sql/hive/TestHive.scala  |   2 +-
     .../hive/execution/CreateTableAsSelect.scala  |   4 +-
     .../spark/sql/hive/execution/commands.scala   |   2 +-
     .../spark/sql/hive/StatisticsSuite.scala      |   4 +-
     17 files changed, 143 insertions(+), 113 deletions(-)
    
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
    index f79d4ff444dc..fc7b8745590d 100755
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
    @@ -178,10 +178,10 @@ class SqlParser extends AbstractSparkSQLParser {
         joinedRelation | relationFactor
     
       protected lazy val relationFactor: Parser[LogicalPlan] =
    -    ( ident ~ (opt(AS) ~> opt(ident)) ^^ {
    -        case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
    +    ( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ {
    +        case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias)
           }
    -    | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
    +      | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
         )
     
       protected lazy val joinedRelation: Parser[LogicalPlan] =
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
    index 72680f37a0b4..c009cc1e1e85 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
    @@ -228,11 +228,11 @@ class Analyzer(catalog: Catalog,
        */
       object ResolveRelations extends Rule[LogicalPlan] {
         def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    -      case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) =>
    +      case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) =>
             i.copy(
    -          table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias)))
    -      case UnresolvedRelation(databaseName, name, alias) =>
    -        catalog.lookupRelation(databaseName, name, alias)
    +          table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias)))
    +      case UnresolvedRelation(tableIdentifier, alias) =>
    +        catalog.lookupRelation(tableIdentifier, alias)
         }
       }
     
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
    index 0415d74bd814..df8d03b86c53 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
    @@ -28,77 +28,74 @@ trait Catalog {
     
       def caseSensitive: Boolean
     
    -  def tableExists(db: Option[String], tableName: String): Boolean
    +  def tableExists(tableIdentifier: Seq[String]): Boolean
     
       def lookupRelation(
    -    databaseName: Option[String],
    -    tableName: String,
    -    alias: Option[String] = None): LogicalPlan
    +      tableIdentifier: Seq[String],
    +      alias: Option[String] = None): LogicalPlan
     
    -  def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit
    +  def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
     
    -  def unregisterTable(databaseName: Option[String], tableName: String): Unit
    +  def unregisterTable(tableIdentifier: Seq[String]): Unit
     
       def unregisterAllTables(): Unit
     
    -  protected def processDatabaseAndTableName(
    -      databaseName: Option[String],
    -      tableName: String): (Option[String], String) = {
    +  protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = {
         if (!caseSensitive) {
    -      (databaseName.map(_.toLowerCase), tableName.toLowerCase)
    +      tableIdentifier.map(_.toLowerCase)
         } else {
    -      (databaseName, tableName)
    +      tableIdentifier
         }
       }
     
    -  protected def processDatabaseAndTableName(
    -      databaseName: String,
    -      tableName: String): (String, String) = {
    -    if (!caseSensitive) {
    -      (databaseName.toLowerCase, tableName.toLowerCase)
    +  protected def getDbTableName(tableIdent: Seq[String]): String = {
    +    val size = tableIdent.size
    +    if (size <= 2) {
    +      tableIdent.mkString(".")
         } else {
    -      (databaseName, tableName)
    +      tableIdent.slice(size - 2, size).mkString(".")
         }
       }
    +
    +  protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
    +    (tableIdent.lift(tableIdent.size - 2), tableIdent.last)
    +  }
     }
     
     class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
       val tables = new mutable.HashMap[String, LogicalPlan]()
     
       override def registerTable(
    -      databaseName: Option[String],
    -      tableName: String,
    +      tableIdentifier: Seq[String],
           plan: LogicalPlan): Unit = {
    -    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    -    tables += ((tblName, plan))
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    tables += ((getDbTableName(tableIdent), plan))
       }
     
    -  override def unregisterTable(
    -      databaseName: Option[String],
    -      tableName: String) = {
    -    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    -    tables -= tblName
    +  override def unregisterTable(tableIdentifier: Seq[String]) = {
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    tables -= getDbTableName(tableIdent)
       }
     
       override def unregisterAllTables() = {
         tables.clear()
       }
     
    -  override def tableExists(db: Option[String], tableName: String): Boolean = {
    -    val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
    -    tables.get(tblName) match {
    +  override def tableExists(tableIdentifier: Seq[String]): Boolean = {
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    tables.get(getDbTableName(tableIdent)) match {
           case Some(_) => true
           case None => false
         }
       }
     
       override def lookupRelation(
    -      databaseName: Option[String],
    -      tableName: String,
    +      tableIdentifier: Seq[String],
           alias: Option[String] = None): LogicalPlan = {
    -    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    -    val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName"))
    -    val tableWithQualifiers = Subquery(tblName, table)
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    val tableFullName = getDbTableName(tableIdent)
    +    val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName"))
    +    val tableWithQualifiers = Subquery(tableIdent.last, table)
     
         // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
         // properly qualified with this alias.
    @@ -117,41 +114,39 @@ trait OverrideCatalog extends Catalog {
       // TODO: This doesn't work when the database changes...
       val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()
     
    -  abstract override def tableExists(db: Option[String], tableName: String): Boolean = {
    -    val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
    -    overrides.get((dbName, tblName)) match {
    +  abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    overrides.get(getDBTable(tableIdent)) match {
           case Some(_) => true
    -      case None => super.tableExists(db, tableName)
    +      case None => super.tableExists(tableIdentifier)
         }
       }
     
       abstract override def lookupRelation(
    -    databaseName: Option[String],
    -    tableName: String,
    +    tableIdentifier: Seq[String],
         alias: Option[String] = None): LogicalPlan = {
    -    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    -    val overriddenTable = overrides.get((dbName, tblName))
    -    val tableWithQualifers = overriddenTable.map(r => Subquery(tblName, r))
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    val overriddenTable = overrides.get(getDBTable(tableIdent))
    +    val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))
     
         // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
         // properly qualified with this alias.
         val withAlias =
           tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r))
     
    -    withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias))
    +    withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
       }
     
       override def registerTable(
    -      databaseName: Option[String],
    -      tableName: String,
    +      tableIdentifier: Seq[String],
           plan: LogicalPlan): Unit = {
    -    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    -    overrides.put((dbName, tblName), plan)
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    overrides.put(getDBTable(tableIdent), plan)
       }
     
    -  override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
    -    val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
    -    overrides.remove((dbName, tblName))
    +  override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    overrides.remove(getDBTable(tableIdent))
       }
     
       override def unregisterAllTables(): Unit = {
    @@ -167,22 +162,21 @@ object EmptyCatalog extends Catalog {
     
       val caseSensitive: Boolean = true
     
    -  def tableExists(db: Option[String], tableName: String): Boolean = {
    +  def tableExists(tableIdentifier: Seq[String]): Boolean = {
         throw new UnsupportedOperationException
       }
     
       def lookupRelation(
    -    databaseName: Option[String],
    -    tableName: String,
    +    tableIdentifier: Seq[String],
         alias: Option[String] = None) = {
         throw new UnsupportedOperationException
       }
     
    -  def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
    +  def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
         throw new UnsupportedOperationException
       }
     
    -  def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
    +  def unregisterTable(tableIdentifier: Seq[String]): Unit = {
         throw new UnsupportedOperationException
       }
     
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
    index 77d84e1687e1..71a738a0b2ca 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
    @@ -34,8 +34,7 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str
      * Holds the name of a relation that has yet to be looked up in a [[Catalog]].
      */
     case class UnresolvedRelation(
    -    databaseName: Option[String],
    -    tableName: String,
    +    tableIdentifier: Seq[String],
         alias: Option[String] = None) extends LeafNode {
       override def output = Nil
       override lazy val resolved = false
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
    index 9608e15c0f30..b2262e5e6efb 100755
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
    @@ -290,7 +290,7 @@ package object dsl {
     
         def insertInto(tableName: String, overwrite: Boolean = false) =
           InsertIntoTable(
    -        analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)
    +        analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite)
     
         def analyze = analysis.SimpleAnalyzer(logicalPlan)
       }
    diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
    index 82f2101d8ce1..f430057ef719 100644
    --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
    +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
    @@ -44,8 +44,8 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
         AttributeReference("e", ShortType)())
     
       before {
    -    caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
    -    caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
    +    caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
    +    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
       }
     
       test("union project *") {
    @@ -64,45 +64,45 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
         assert(
           caseSensitiveAnalyze(
             Project(Seq(UnresolvedAttribute("TbL.a")),
    -          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
    +          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
             Project(testRelation.output, testRelation))
     
         val e = intercept[TreeNodeException[_]] {
           caseSensitiveAnalyze(
             Project(Seq(UnresolvedAttribute("tBl.a")),
    -          UnresolvedRelation(None, "TaBlE", Some("TbL"))))
    +          UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
         }
         assert(e.getMessage().toLowerCase.contains("unresolved"))
     
         assert(
           caseInsensitiveAnalyze(
             Project(Seq(UnresolvedAttribute("TbL.a")),
    -          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
    +          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
             Project(testRelation.output, testRelation))
     
         assert(
           caseInsensitiveAnalyze(
             Project(Seq(UnresolvedAttribute("tBl.a")),
    -          UnresolvedRelation(None, "TaBlE", Some("TbL")))) ===
    +          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
             Project(testRelation.output, testRelation))
       }
     
       test("resolve relations") {
         val e = intercept[RuntimeException] {
    -      caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
    +      caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
         }
         assert(e.getMessage == "Table Not Found: tAbLe")
     
         assert(
    -      caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
    +      caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
             testRelation)
     
         assert(
    -      caseInsensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) ===
    +      caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
             testRelation)
     
         assert(
    -      caseInsensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
    +      caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
             testRelation)
       }
     
    diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
    index 3677a6e72e23..bbbeb4f2e4fe 100644
    --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
    +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
    @@ -41,7 +41,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
       val f: Expression = UnresolvedAttribute("f")
     
       before {
    -    catalog.registerTable(None, "table", relation)
    +    catalog.registerTable(Seq("table"), relation)
       }
     
       private def checkType(expression: Expression, expectedType: DataType): Unit = {
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    index 6a1a4d995bf6..9962937277da 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    @@ -276,7 +276,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
        * @group userf
        */
       def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
    -    catalog.registerTable(None, tableName, rdd.queryExecution.logical)
    +    catalog.registerTable(Seq(tableName), rdd.queryExecution.logical)
       }
     
       /**
    @@ -289,7 +289,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
        */
       def dropTempTable(tableName: String): Unit = {
         tryUncacheQuery(table(tableName))
    -    catalog.unregisterTable(None, tableName)
    +    catalog.unregisterTable(Seq(tableName))
       }
     
       /**
    @@ -308,7 +308,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     
       /** Returns the specified table as a SchemaRDD */
       def table(tableName: String): SchemaRDD =
    -    new SchemaRDD(this, catalog.lookupRelation(None, tableName))
    +    new SchemaRDD(this, catalog.lookupRelation(Seq(tableName)))
     
       /**
        * :: DeveloperApi ::
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
    index fd5f4abcbcd6..3cf9209465b7 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
    @@ -97,8 +97,8 @@ private[sql] trait SchemaRDDLike {
        */
       @Experimental
       def insertInto(tableName: String, overwrite: Boolean): Unit =
    -    sqlContext.executePlan(
    -      InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)).toRdd
    +    sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
    +      Map.empty, logicalPlan, overwrite)).toRdd
     
       /**
        * :: Experimental ::
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
    index 1a4232dab86e..c7e136388fce 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
    @@ -302,8 +302,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
         upperCaseData.where('N <= 4).registerTempTable("left")
         upperCaseData.where('N >= 3).registerTempTable("right")
     
    -    val left = UnresolvedRelation(None, "left", None)
    -    val right = UnresolvedRelation(None, "right", None)
    +    val left = UnresolvedRelation(Seq("left"), None)
    +    val right = UnresolvedRelation(Seq("right"), None)
     
         checkAnswer(
           left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
    index 982e0593fcfd..1648fa826b90 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
    @@ -124,7 +124,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
        * in the Hive metastore.
        */
       def analyze(tableName: String) {
    -    val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName))
    +    val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName)))
     
         relation match {
           case relation: MetastoreRelation =>
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
    index b31a3ec25096..2c859894cf8d 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
    @@ -30,6 +30,7 @@ import org.apache.hadoop.hive.metastore.TableType
     import org.apache.hadoop.hive.metastore.api.FieldSchema
     import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
     import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table, HiveException}
    +import org.apache.hadoop.hive.ql.metadata.InvalidTableException
     import org.apache.hadoop.hive.ql.plan.CreateTableDesc
     import org.apache.hadoop.hive.serde.serdeConstants
     import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
    @@ -57,18 +58,25 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
     
       val caseSensitive: Boolean = false
     
    -  def tableExists(db: Option[String], tableName: String): Boolean = {
    -    val (databaseName, tblName) = processDatabaseAndTableName(
    -      db.getOrElse(hive.sessionState.getCurrentDatabase), tableName)
    -    client.getTable(databaseName, tblName, false) != null
    +  def tableExists(tableIdentifier: Seq[String]): Boolean = {
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
    +      hive.sessionState.getCurrentDatabase)
    +    val tblName = tableIdent.last
    +    try {
    +      client.getTable(databaseName, tblName) != null
    +    } catch {
    +      case ie: InvalidTableException => false
    +    }
       }
     
       def lookupRelation(
    -      db: Option[String],
    -      tableName: String,
    +      tableIdentifier: Seq[String],
           alias: Option[String]): LogicalPlan = synchronized {
    -    val (databaseName, tblName) =
    -      processDatabaseAndTableName(db.getOrElse(hive.sessionState.getCurrentDatabase), tableName)
    +    val tableIdent = processTableIdentifier(tableIdentifier)
    +    val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
    +      hive.sessionState.getCurrentDatabase)
    +    val tblName = tableIdent.last
         val table = client.getTable(databaseName, tblName)
         if (table.isView) {
           // if the unresolved relation is from hive view
    @@ -251,6 +259,26 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
         }
       }
     
    +  protected def processDatabaseAndTableName(
    +      databaseName: Option[String],
    +      tableName: String): (Option[String], String) = {
    +    if (!caseSensitive) {
    +      (databaseName.map(_.toLowerCase), tableName.toLowerCase)
    +    } else {
    +      (databaseName, tableName)
    +    }
    +  }
    +
    +  protected def processDatabaseAndTableName(
    +      databaseName: String,
    +      tableName: String): (String, String) = {
    +    if (!caseSensitive) {
    +      (databaseName.toLowerCase, tableName.toLowerCase)
    +    } else {
    +      (databaseName, tableName)
    +    }
    +  }
    +
       /**
        * Creates any tables required for query execution.
        * For example, because of a CREATE TABLE X AS statement.
    @@ -270,7 +298,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
             val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
     
             // Get the CreateTableDesc from Hive SemanticAnalyzer
    -        val desc: Option[CreateTableDesc] = if (tableExists(Some(databaseName), tblName)) {
    +        val desc: Option[CreateTableDesc] = if (tableExists(Seq(databaseName, tblName))) {
               None
             } else {
               val sa = new SemanticAnalyzer(hive.hiveconf) {
    @@ -352,15 +380,13 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
        * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
        * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]].
        */
    -  override def registerTable(
    -      databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
    +  override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ???
     
       /**
        * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
        * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]].
        */
    -  override def unregisterTable(
    -      databaseName: Option[String], tableName: String): Unit = ???
    +  override def unregisterTable(tableIdentifier: Seq[String]): Unit = ???
     
       override def unregisterAllTables() = {}
     }
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    index 8a9613cf96e5..c2ab3579c1f9 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    @@ -386,6 +386,15 @@ private[hive] object HiveQl {
         (db, tableName)
       }
     
    +  protected def extractTableIdent(tableNameParts: Node): Seq[String] = {
    +    tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
    +      case Seq(tableOnly) => Seq(tableOnly)
    +      case Seq(databaseName, table) => Seq(databaseName, table)
+      case other => sys.error("Hive only supports table names like 'tableName' " +
    +        s"or 'databaseName.tableName', found '$other'")
    +    }
    +  }
    +
       /**
        * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) 
        * is equivalent to 
    @@ -475,16 +484,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                   case Token(".", dbName :: tableName :: Nil) =>
                     // It is describing a table with the format like "describe db.table".
                     // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue.
    -                val (db, tableName) = extractDbNameTableName(nameParts.head)
    +                val tableIdent = extractTableIdent(nameParts.head)
                     DescribeCommand(
    -                  UnresolvedRelation(db, tableName, None), extended.isDefined)
    +                  UnresolvedRelation(tableIdent, None), extended.isDefined)
                   case Token(".", dbName :: tableName :: colName :: Nil) =>
                     // It is describing a column with the format like "describe db.table column".
                     NativePlaceholder
                   case tableName =>
                     // It is describing a table with the format like "describe table".
                     DescribeCommand(
    -                  UnresolvedRelation(None, tableName.getText, None),
    +                  UnresolvedRelation(Seq(tableName.getText), None),
                       extended.isDefined)
                 }
               }
    @@ -757,13 +766,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
               nonAliasClauses)
           }
     
    -      val (db, tableName) =
    +      val tableIdent =
             tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
    -          case Seq(tableOnly) => (None, tableOnly)
    -          case Seq(databaseName, table) => (Some(databaseName), table)
    +          case Seq(tableOnly) => Seq(tableOnly)
    +          case Seq(databaseName, table) => Seq(databaseName, table)
+          case other => sys.error("Hive only supports table names like 'tableName' " +
    +            s"or 'databaseName.tableName', found '$other'")
           }
           val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
    -      val relation = UnresolvedRelation(db, tableName, alias)
    +      val relation = UnresolvedRelation(tableIdent, alias)
     
           // Apply sampling if requested.
           (bucketSampleClause orElse splitSampleClause).map {
    @@ -882,7 +893,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           val Some(tableNameParts) :: partitionClause :: Nil =
             getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
     
    -      val (db, tableName) = extractDbNameTableName(tableNameParts)
    +      val tableIdent = extractTableIdent(tableNameParts)
     
           val partitionKeys = partitionClause.map(_.getChildren.map {
             // Parse partitions. We also make keys case insensitive.
    @@ -892,7 +903,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
               cleanIdentifier(key.toLowerCase) -> None
           }.toMap).getOrElse(Map.empty)
     
    -      InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite)
    +      InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite)
     
         case a: ASTNode =>
           throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
    index b2149bd95a33..8f2311cf83eb 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
    @@ -167,7 +167,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
           // Make sure any test tables referenced are loaded.
           val referencedTables =
             describedTables ++
    -        logical.collect { case UnresolvedRelation(databaseName, name, _) => name }
    +        logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last }
           val referencedTestTables = referencedTables.filter(testTables.contains)
           logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}")
           referencedTestTables.foreach(loadTestTable)
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
    index fe21454e7fb3..a547babcebff 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
    @@ -53,14 +53,14 @@ case class CreateTableAsSelect(
           hiveContext.catalog.createTable(database, tableName, query.output, allowExisting, desc)
     
           // Get the Metastore Relation
    -      hiveContext.catalog.lookupRelation(Some(database), tableName, None) match {
    +      hiveContext.catalog.lookupRelation(Seq(database, tableName), None) match {
             case r: MetastoreRelation => r
           }
         }
         // TODO ideally, we should get the output data ready first and then
         // add the relation into catalog, just in case of failure occurs while data
         // processing.
    -    if (hiveContext.catalog.tableExists(Some(database), tableName)) {
    +    if (hiveContext.catalog.tableExists(Seq(database, tableName))) {
           if (allowExisting) {
             // table already exists, will do nothing, to keep consistent with Hive
           } else {
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
    index 6fc4153f6a5d..6b733a280e6d 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
    @@ -53,7 +53,7 @@ case class DropTable(
         val hiveContext = sqlContext.asInstanceOf[HiveContext]
         val ifExistsClause = if (ifExists) "IF EXISTS " else ""
         hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
    -    hiveContext.catalog.unregisterTable(None, tableName)
    +    hiveContext.catalog.unregisterTable(Seq(tableName))
         Seq.empty[Row]
       }
     }
    diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
    index 4b6a9308b981..a758f921e041 100644
    --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
    +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
    @@ -72,7 +72,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
     
       test("analyze MetastoreRelations") {
         def queryTotalSize(tableName: String): BigInt =
    -      catalog.lookupRelation(None, tableName).statistics.sizeInBytes
    +      catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
     
         // Non-partitioned table
         sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
    @@ -123,7 +123,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
         intercept[NotImplementedError] {
           analyze("tempTable")
         }
    -    catalog.unregisterTable(None, "tempTable")
    +    catalog.unregisterTable(Seq("tempTable"))
       }
     
       test("estimates the size of a test MetastoreRelation") {
    
    From 693a323a70aba91e6c100dd5561d218a75b7895e Mon Sep 17 00:00:00 2001
    From: scwf 
    Date: Sat, 10 Jan 2015 13:53:21 -0800
    Subject: [PATCH 109/215] [SPARK-4574][SQL] Adding support for defining schema
     in foreign DDL commands.
    
Adding support for defining a schema in foreign DDL commands. Foreign DDL now supports commands like:
    ```
    CREATE TEMPORARY TABLE avroTable
    USING org.apache.spark.sql.avro
    OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
    ```
With this PR, users can define the schema instead of inferring it from the file, so DDL commands like the following are supported:
    ```
    CREATE TEMPORARY TABLE avroTable(a int, b string)
    USING org.apache.spark.sql.avro
    OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
    ```
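
As an illustration only (not part of this patch), the column list in the second statement is parsed by the new ```DDLParser``` into a user-specified schema roughly equivalent to the sketch below; the nullability default is assumed from the ```column``` parser added in this change:

```scala
import org.apache.spark.sql.catalyst.types._

// Sketch: what "(a int, b string)" becomes as a user-specified schema.
// StructField's nullable defaults to true, matching the `column` parser.
val userSpecifiedSchema: Option[StructType] = Some(StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType))))
```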
    
    Author: scwf 
    Author: Yin Huai 
    Author: Fei Wang 
    Author: wangfei 
    
    Closes #3431 from scwf/ddl and squashes the following commits:
    
    7e79ce5 [Fei Wang] Merge pull request #22 from yhuai/pr3431yin
    38f634e [Yin Huai] Remove Option from createRelation.
    65e9c73 [Yin Huai] Revert all changes since applying a given schema has not been testd.
    a852b10 [scwf] remove cleanIdentifier
    f336a16 [Fei Wang] Merge pull request #21 from yhuai/pr3431yin
    baf79b5 [Yin Huai] Test special characters quoted by backticks.
    50a03b0 [Yin Huai] Use JsonRDD.nullTypeToStringType to convert NullType to StringType.
    1eeb769 [Fei Wang] Merge pull request #20 from yhuai/pr3431yin
    f5c22b0 [Yin Huai] Refactor code and update test cases.
    f1cffe4 [Yin Huai] Revert "minor refactory"
    b621c8f [scwf] minor refactory
    d02547f [scwf] fix HiveCompatibilitySuite test failure
    8dfbf7a [scwf] more tests for complex data type
    ddab984 [Fei Wang] Merge pull request #19 from yhuai/pr3431yin
    91ad91b [Yin Huai] Parse data types in DDLParser.
    cf982d2 [scwf] fixed test failure
    445b57b [scwf] address comments
    02a662c [scwf] style issue
    44eb70c [scwf] fix decimal parser issue
    83b6fc3 [scwf] minor fix
    9bf12f8 [wangfei] adding test case
    7787ec7 [wangfei] added SchemaRelationProvider
    0ba70df [wangfei] draft version
    ---
     .../apache/spark/sql/json/JSONRelation.scala  |  35 +++-
     .../org/apache/spark/sql/sources/ddl.scala    | 138 +++++++++++--
     .../apache/spark/sql/sources/interfaces.scala |  29 ++-
     .../spark/sql/sources/TableScanSuite.scala    | 192 ++++++++++++++++++
     .../spark/sql/hive/HiveMetastoreCatalog.scala | 114 +++--------
     .../sql/hive/HiveMetastoreCatalogSuite.scala  |   5 +-
     6 files changed, 400 insertions(+), 113 deletions(-)
    
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
    index fc70c183437f..a9a6696cb15e 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
    @@ -18,31 +18,48 @@
     package org.apache.spark.sql.json
     
     import org.apache.spark.sql.SQLContext
    +import org.apache.spark.sql.catalyst.types.StructType
     import org.apache.spark.sql.sources._
     
    -private[sql] class DefaultSource extends RelationProvider {
    -  /** Returns a new base relation with the given parameters. */
    +private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider {
    +
    +  /** Returns a new base relation with the parameters. */
       override def createRelation(
           sqlContext: SQLContext,
           parameters: Map[String, String]): BaseRelation = {
         val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
         val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
     
    -    JSONRelation(fileName, samplingRatio)(sqlContext)
    +    JSONRelation(fileName, samplingRatio, None)(sqlContext)
    +  }
    +
    +  /** Returns a new base relation with the given schema and parameters. */
    +  override def createRelation(
    +      sqlContext: SQLContext,
    +      parameters: Map[String, String],
    +      schema: StructType): BaseRelation = {
    +    val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
    +    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
    +
    +    JSONRelation(fileName, samplingRatio, Some(schema))(sqlContext)
       }
     }
     
    -private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
    +private[sql] case class JSONRelation(
    +    fileName: String,
    +    samplingRatio: Double,
    +    userSpecifiedSchema: Option[StructType])(
         @transient val sqlContext: SQLContext)
       extends TableScan {
     
       private def baseRDD = sqlContext.sparkContext.textFile(fileName)
     
    -  override val schema =
    -    JsonRDD.inferSchema(
    -      baseRDD,
    -      samplingRatio,
    -      sqlContext.columnNameOfCorruptRecord)
    +  override val schema = userSpecifiedSchema.getOrElse(
    +    JsonRDD.nullTypeToStringType(
    +      JsonRDD.inferSchema(
    +        baseRDD,
    +        samplingRatio,
    +        sqlContext.columnNameOfCorruptRecord)))
     
       override def buildScan() =
         JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
    index 364bacec83b9..fe2c4d8436b2 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
    @@ -17,16 +17,15 @@
     
     package org.apache.spark.sql.sources
     
    -import org.apache.spark.Logging
    -import org.apache.spark.sql.SQLContext
    -import org.apache.spark.sql.execution.RunnableCommand
    -import org.apache.spark.util.Utils
    -
     import scala.language.implicitConversions
    -import scala.util.parsing.combinator.lexical.StdLexical
     import scala.util.parsing.combinator.syntactical.StandardTokenParsers
     import scala.util.parsing.combinator.PackratParsers
     
    +import org.apache.spark.Logging
    +import org.apache.spark.sql.SQLContext
    +import org.apache.spark.sql.catalyst.types._
    +import org.apache.spark.sql.execution.RunnableCommand
    +import org.apache.spark.util.Utils
     import org.apache.spark.sql.catalyst.plans.logical._
     import org.apache.spark.sql.catalyst.SqlLexical
     
    @@ -44,6 +43,14 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
         }
       }
     
    +  def parseType(input: String): DataType = {
    +    phrase(dataType)(new lexical.Scanner(input)) match {
    +      case Success(r, x) => r
    +      case x =>
    +        sys.error(s"Unsupported dataType: $x")
    +    }
    +  }
    +
       protected case class Keyword(str: String)
     
       protected implicit def asParser(k: Keyword): Parser[String] =
    @@ -55,6 +62,24 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
       protected val USING = Keyword("USING")
       protected val OPTIONS = Keyword("OPTIONS")
     
    +  // Data types.
    +  protected val STRING = Keyword("STRING")
    +  protected val BINARY = Keyword("BINARY")
    +  protected val BOOLEAN = Keyword("BOOLEAN")
    +  protected val TINYINT = Keyword("TINYINT")
    +  protected val SMALLINT = Keyword("SMALLINT")
    +  protected val INT = Keyword("INT")
    +  protected val BIGINT = Keyword("BIGINT")
    +  protected val FLOAT = Keyword("FLOAT")
    +  protected val DOUBLE = Keyword("DOUBLE")
    +  protected val DECIMAL = Keyword("DECIMAL")
    +  protected val DATE = Keyword("DATE")
    +  protected val TIMESTAMP = Keyword("TIMESTAMP")
    +  protected val VARCHAR = Keyword("VARCHAR")
    +  protected val ARRAY = Keyword("ARRAY")
    +  protected val MAP = Keyword("MAP")
    +  protected val STRUCT = Keyword("STRUCT")
    +
       // Use reflection to find the reserved words defined in this class.
       protected val reservedWords =
         this.getClass
    @@ -67,15 +92,25 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
       protected lazy val ddl: Parser[LogicalPlan] = createTable
     
       /**
    -   * CREATE TEMPORARY TABLE avroTable
    +   * `CREATE TEMPORARY TABLE avroTable
        * USING org.apache.spark.sql.avro
    -   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
    +   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
    +   * or
    +   * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
    +   * USING org.apache.spark.sql.avro
    +   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
        */
       protected lazy val createTable: Parser[LogicalPlan] =
    -    CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
    -      case tableName ~ provider ~ opts =>
    -        CreateTableUsing(tableName, provider, opts)
    +  (
    +    CREATE ~ TEMPORARY ~ TABLE ~> ident
    +      ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
    +      case tableName ~ columns ~ provider ~ opts =>
    +        val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
    +        CreateTableUsing(tableName, userSpecifiedSchema, provider, opts)
         }
    +  )
    +
    +  protected lazy val tableCols: Parser[Seq[StructField]] =  "(" ~> repsep(column, ",") <~ ")"
     
       protected lazy val options: Parser[Map[String, String]] =
         "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
    @@ -83,10 +118,66 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
       protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
     
       protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) }
    +
    +  protected lazy val column: Parser[StructField] =
    +    ident ~ dataType ^^ { case columnName ~ typ =>
    +      StructField(columnName, typ)
    +    }
    +
    +  protected lazy val primitiveType: Parser[DataType] =
    +    STRING ^^^ StringType |
    +    BINARY ^^^ BinaryType |
    +    BOOLEAN ^^^ BooleanType |
    +    TINYINT ^^^ ByteType |
    +    SMALLINT ^^^ ShortType |
    +    INT ^^^ IntegerType |
    +    BIGINT ^^^ LongType |
    +    FLOAT ^^^ FloatType |
    +    DOUBLE ^^^ DoubleType |
    +    fixedDecimalType |                   // decimal with precision/scale
    +    DECIMAL ^^^ DecimalType.Unlimited |  // decimal with no precision/scale
    +    DATE ^^^ DateType |
    +    TIMESTAMP ^^^ TimestampType |
    +    VARCHAR ~ "(" ~ numericLit ~ ")" ^^^ StringType
    +
    +  protected lazy val fixedDecimalType: Parser[DataType] =
    +    (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
    +      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
    +    }
    +
    +  protected lazy val arrayType: Parser[DataType] =
    +    ARRAY ~> "<" ~> dataType <~ ">" ^^ {
    +      case tpe => ArrayType(tpe)
    +    }
    +
    +  protected lazy val mapType: Parser[DataType] =
    +    MAP ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
    +      case t1 ~ _ ~ t2 => MapType(t1, t2)
    +    }
    +
    +  protected lazy val structField: Parser[StructField] =
    +    ident ~ ":" ~ dataType ^^ {
    +      case fieldName ~ _ ~ tpe => StructField(fieldName, tpe, nullable = true)
    +    }
    +
    +  protected lazy val structType: Parser[DataType] =
    +    (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
    +    case fields => new StructType(fields)
    +    }) |
    +    (STRUCT ~> "<>" ^^ {
    +      case fields => new StructType(Nil)
    +    })
    +
    +  private[sql] lazy val dataType: Parser[DataType] =
    +    arrayType |
    +    mapType |
    +    structType |
    +    primitiveType
     }
     
     private[sql] case class CreateTableUsing(
         tableName: String,
    +    userSpecifiedSchema: Option[StructType],
         provider: String,
         options: Map[String, String]) extends RunnableCommand {
     
    @@ -99,8 +190,29 @@ private[sql] case class CreateTableUsing(
                 sys.error(s"Failed to load class for data source: $provider")
             }
         }
    -    val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
    -    val relation = dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
    +
    +    val relation = userSpecifiedSchema match {
    +      case Some(schema: StructType) => {
    +        clazz.newInstance match {
    +          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
    +            dataSource
    +              .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider]
    +              .createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
    +          case _ =>
    +            sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.")
    +        }
    +      }
    +      case None => {
    +        clazz.newInstance match {
    +          case dataSource: org.apache.spark.sql.sources.RelationProvider  =>
    +            dataSource
    +              .asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
    +              .createRelation(sqlContext, new CaseInsensitiveMap(options))
    +          case _ =>
    +            sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.")
    +        }
    +      }
    +    }
     
         sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
         Seq.empty
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
    index 02eff80456db..990f7e0e74bc 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
    @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
     
     import org.apache.spark.annotation.{Experimental, DeveloperApi}
     import org.apache.spark.rdd.RDD
    -import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
    +import org.apache.spark.sql.{Row, SQLContext, StructType}
     import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
     
     /**
    @@ -44,6 +44,33 @@ trait RelationProvider {
       def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
     }
     
    +/**
    + * ::DeveloperApi::
    + * Implemented by objects that produce relations for a specific kind of data source.  When
    + * Spark SQL is given a DDL operation with
    + * 1. USING clause: to specify the implemented SchemaRelationProvider
+ * 2. User defined schema: users can optionally define a schema when creating a table
    + *
    + * Users may specify the fully qualified class name of a given data source.  When that class is
    + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for
    + * less verbose invocation.  For example, 'org.apache.spark.sql.json' would resolve to the
    + * data source 'org.apache.spark.sql.json.DefaultSource'
    + *
+ * A new instance of this class will be instantiated each time a DDL call is made.
    + */
    +@DeveloperApi
    +trait SchemaRelationProvider {
    +  /**
    +   * Returns a new base relation with the given parameters and user defined schema.
    +   * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
    +   * by the Map that is passed to the function.
    +   */
    +  def createRelation(
    +      sqlContext: SQLContext,
    +      parameters: Map[String, String],
    +      schema: StructType): BaseRelation
    +}
    +
     /**
      * ::DeveloperApi::
      * Represents a collection of tuples with a known schema.  Classes that extend BaseRelation must
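
As a quick, non-authoritative sketch of the contract above (the class names and the ```numRows``` option here are invented for illustration; only the ```SchemaRelationProvider``` and ```TableScan``` signatures come from this patch and the existing sources API):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, StructType}
import org.apache.spark.sql.sources.{BaseRelation, SchemaRelationProvider, TableScan}

// Hypothetical data source that simply honors whatever schema the user declared
// in the DDL and returns `numRows` rows of nulls shaped by that schema.
class NullRowsSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    NullRowsRelation(parameters("numRows").toInt, schema)(sqlContext)
}

case class NullRowsRelation(numRows: Int, userSchema: StructType)(
    @transient val sqlContext: SQLContext) extends TableScan {

  // The user-specified schema is reported back verbatim.
  override def schema: StructType = userSchema

  // Each row carries one null per field of the declared schema.
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(1 to numRows)
      .map(_ => Row(userSchema.fields.map(_ => null): _*))
}
```
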
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
    index 3cd7b0115d56..605190f5ae6a 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
    @@ -17,7 +17,10 @@
     
     package org.apache.spark.sql.sources
     
    +import java.sql.{Timestamp, Date}
    +
     import org.apache.spark.sql._
    +import org.apache.spark.sql.catalyst.types.DecimalType
     
     class DefaultSource extends SimpleScanSource
     
    @@ -38,9 +41,77 @@ case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
       override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
     }
     
    +class AllDataTypesScanSource extends SchemaRelationProvider {
    +  override def createRelation(
    +      sqlContext: SQLContext,
    +      parameters: Map[String, String],
    +      schema: StructType): BaseRelation = {
    +    AllDataTypesScan(parameters("from").toInt, parameters("TO").toInt, schema)(sqlContext)
    +  }
    +}
    +
    +case class AllDataTypesScan(
    +  from: Int,
    +  to: Int,
    +  userSpecifiedSchema: StructType)(@transient val sqlContext: SQLContext)
    +  extends TableScan {
    +
    +  override def schema = userSpecifiedSchema
    +
    +  override def buildScan() = {
    +    sqlContext.sparkContext.parallelize(from to to).map { i =>
    +      Row(
    +        s"str_$i",
    +        s"str_$i".getBytes(),
    +        i % 2 == 0,
    +        i.toByte,
    +        i.toShort,
    +        i,
    +        i.toLong,
    +        i.toFloat,
    +        i.toDouble,
    +        BigDecimal(i),
    +        BigDecimal(i),
    +        new Date((i + 1) * 8640000),
    +        new Timestamp(20000 + i),
    +        s"varchar_$i",
    +        Seq(i, i + 1),
    +        Seq(Map(s"str_$i" -> Row(i.toLong))),
    +        Map(i -> i.toString),
    +        Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
    +        Row(i, i.toString),
    +        Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
    +    }
    +  }
    +}
    +
     class TableScanSuite extends DataSourceTest {
       import caseInsensisitiveContext._
     
    +  var tableWithSchemaExpected = (1 to 10).map { i =>
    +    Row(
    +      s"str_$i",
    +      s"str_$i",
    +      i % 2 == 0,
    +      i.toByte,
    +      i.toShort,
    +      i,
    +      i.toLong,
    +      i.toFloat,
    +      i.toDouble,
    +      BigDecimal(i),
    +      BigDecimal(i),
    +      new Date((i + 1) * 8640000),
    +      new Timestamp(20000 + i),
    +      s"varchar_$i",
    +      Seq(i, i + 1),
    +      Seq(Map(s"str_$i" -> Row(i.toLong))),
    +      Map(i -> i.toString),
    +      Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)),
    +      Row(i, i.toString),
    +      Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000)))))
    +  }.toSeq
    +
       before {
         sql(
           """
    @@ -51,6 +122,37 @@ class TableScanSuite extends DataSourceTest {
             |  To '10'
             |)
           """.stripMargin)
    +
    +    sql(
    +      """
    +        |CREATE TEMPORARY TABLE tableWithSchema (
    +        |`string$%Field` stRIng,
    +        |binaryField binary,
    +        |`booleanField` boolean,
    +        |ByteField tinyint,
    +        |shortField smaLlint,
    +        |int_Field iNt,
    +        |`longField_:,<>=+/~^` Bigint,
    +        |floatField flOat,
    +        |doubleField doubLE,
    +        |decimalField1 decimal,
    +        |decimalField2 decimal(9,2),
    +        |dateField dAte,
    +        |timestampField tiMestamp,
    +        |varcharField varchaR(12),
+        |arrayFieldSimple Array<inT>,
+        |arrayFieldComplex Array<Map<String, Struct<key:bigInt>>>,
+        |mapFieldSimple MAP<iNt, StRing>,
+        |mapFieldComplex Map<Map<stRING, fLOAT>, Struct<key:bigInt>>,
+        |structFieldSimple StRuct<key:INt, Value:STrINg>,
+        |structFieldComplex StRuct<key:Array<String>, Value:struct<`value_(2)`:Array<date>>>
    +        |)
    +        |USING org.apache.spark.sql.sources.AllDataTypesScanSource
    +        |OPTIONS (
    +        |  From '1',
    +        |  To '10'
    +        |)
    +      """.stripMargin)
       }
     
       sqlTest(
    @@ -73,6 +175,96 @@ class TableScanSuite extends DataSourceTest {
         "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
         (2 to 10).map(i => Row(i, i - 1)).toSeq)
     
    +  test("Schema and all fields") {
    +    val expectedSchema = StructType(
    +      StructField("string$%Field", StringType, true) ::
    +      StructField("binaryField", BinaryType, true) ::
    +      StructField("booleanField", BooleanType, true) ::
    +      StructField("ByteField", ByteType, true) ::
    +      StructField("shortField", ShortType, true) ::
    +      StructField("int_Field", IntegerType, true) ::
    +      StructField("longField_:,<>=+/~^", LongType, true) ::
    +      StructField("floatField", FloatType, true) ::
    +      StructField("doubleField", DoubleType, true) ::
    +      StructField("decimalField1", DecimalType.Unlimited, true) ::
    +      StructField("decimalField2", DecimalType(9, 2), true) ::
    +      StructField("dateField", DateType, true) ::
    +      StructField("timestampField", TimestampType, true) ::
    +      StructField("varcharField", StringType, true) ::
    +      StructField("arrayFieldSimple", ArrayType(IntegerType), true) ::
    +      StructField("arrayFieldComplex",
    +        ArrayType(
    +          MapType(StringType, StructType(StructField("key", LongType, true) :: Nil))), true) ::
    +      StructField("mapFieldSimple", MapType(IntegerType, StringType), true) ::
    +      StructField("mapFieldComplex",
    +        MapType(
    +          MapType(StringType, FloatType),
    +          StructType(StructField("key", LongType, true) :: Nil)), true) ::
    +      StructField("structFieldSimple",
    +        StructType(
    +          StructField("key", IntegerType, true) ::
    +          StructField("Value", StringType, true) ::  Nil), true) ::
    +      StructField("structFieldComplex",
    +        StructType(
    +          StructField("key", ArrayType(StringType), true) ::
    +          StructField("Value",
    +            StructType(
    +              StructField("value_(2)", ArrayType(DateType), true) :: Nil), true) ::  Nil), true) ::
    +      Nil
    +    )
    +
    +    assert(expectedSchema == table("tableWithSchema").schema)
    +
    +    checkAnswer(
    +      sql(
    +        """SELECT
    +          | `string$%Field`,
    +          | cast(binaryField as string),
    +          | booleanField,
    +          | byteField,
    +          | shortField,
    +          | int_Field,
    +          | `longField_:,<>=+/~^`,
    +          | floatField,
    +          | doubleField,
    +          | decimalField1,
    +          | decimalField2,
    +          | dateField,
    +          | timestampField,
    +          | varcharField,
    +          | arrayFieldSimple,
    +          | arrayFieldComplex,
    +          | mapFieldSimple,
    +          | mapFieldComplex,
    +          | structFieldSimple,
    +          | structFieldComplex FROM tableWithSchema""".stripMargin),
    +      tableWithSchemaExpected
    +    )
    +  }
    +
    +  sqlTest(
    +    "SELECT count(*) FROM tableWithSchema",
    +    10)
    +
    +  sqlTest(
    +    "SELECT `string$%Field` FROM tableWithSchema",
    +    (1 to 10).map(i => Row(s"str_$i")).toSeq)
    +
    +  sqlTest(
    +    "SELECT int_Field FROM tableWithSchema WHERE int_Field < 5",
    +    (1 to 4).map(Row(_)).toSeq)
    +
    +  sqlTest(
    +    "SELECT `longField_:,<>=+/~^` * 2 FROM tableWithSchema",
    +    (1 to 10).map(i => Row(i * 2.toLong)).toSeq)
    +
    +  sqlTest(
    +    "SELECT structFieldSimple.key, arrayFieldSimple[1] FROM tableWithSchema a where int_Field=1",
    +    Seq(Seq(1, 2)))
    +
    +  sqlTest(
    +    "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema",
    +    (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq)
     
       test("Caching")  {
         // Cached Query Execution
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
    index 2c859894cf8d..c25288e00012 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
    @@ -20,12 +20,7 @@ package org.apache.spark.sql.hive
     import java.io.IOException
     import java.util.{List => JList}
     
    -import org.apache.spark.sql.execution.SparkPlan
    -
    -import scala.util.parsing.combinator.RegexParsers
    -
     import org.apache.hadoop.util.ReflectionUtils
    -
     import org.apache.hadoop.hive.metastore.TableType
     import org.apache.hadoop.hive.metastore.api.FieldSchema
     import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
    @@ -37,7 +32,6 @@ import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
     import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
     
     import org.apache.spark.Logging
    -import org.apache.spark.annotation.DeveloperApi
     import org.apache.spark.sql.SQLContext
     import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog}
     import org.apache.spark.sql.catalyst.expressions._
    @@ -412,88 +406,6 @@ private[hive] case class InsertIntoHiveTable(
       }
     }
     
    -/**
    - * :: DeveloperApi ::
    - * Provides conversions between Spark SQL data types and Hive Metastore types.
    - */
    -@DeveloperApi
    -object HiveMetastoreTypes extends RegexParsers {
    -  protected lazy val primitiveType: Parser[DataType] =
    -    "string" ^^^ StringType |
    -    "float" ^^^ FloatType |
    -    "int" ^^^ IntegerType |
    -    "tinyint" ^^^ ByteType |
    -    "smallint" ^^^ ShortType |
    -    "double" ^^^ DoubleType |
    -    "bigint" ^^^ LongType |
    -    "binary" ^^^ BinaryType |
    -    "boolean" ^^^ BooleanType |
    -    fixedDecimalType |                     // Hive 0.13+ decimal with precision/scale
    -    "decimal" ^^^ DecimalType.Unlimited |  // Hive 0.12 decimal with no precision/scale
    -    "date" ^^^ DateType |
    -    "timestamp" ^^^ TimestampType |
    -    "varchar\\((\\d+)\\)".r ^^^ StringType
    -
    -  protected lazy val fixedDecimalType: Parser[DataType] =
    -    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
    -      case precision ~ scale =>
    -        DecimalType(precision.toInt, scale.toInt)
    -    }
    -
    -  protected lazy val arrayType: Parser[DataType] =
    -    "array" ~> "<" ~> dataType <~ ">" ^^ {
    -      case tpe => ArrayType(tpe)
    -    }
    -
    -  protected lazy val mapType: Parser[DataType] =
    -    "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
    -      case t1 ~ _ ~ t2 => MapType(t1, t2)
    -    }
    -
    -  protected lazy val structField: Parser[StructField] =
    -    "[a-zA-Z0-9_]*".r ~ ":" ~ dataType ^^ {
    -      case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
    -    }
    -
    -  protected lazy val structType: Parser[DataType] =
    -    "struct" ~> "<" ~> repsep(structField,",") <~ ">"  ^^ {
    -      case fields => new StructType(fields)
    -    }
    -
    -  protected lazy val dataType: Parser[DataType] =
    -    arrayType |
    -    mapType |
    -    structType |
    -    primitiveType
    -
    -  def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
    -    case Success(result, _) => result
    -    case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
    -  }
    -
    -  def toMetastoreType(dt: DataType): String = dt match {
    -    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
    -    case StructType(fields) =>
    -      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
    -    case MapType(keyType, valueType, _) =>
    -      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
    -    case StringType => "string"
    -    case FloatType => "float"
    -    case IntegerType => "int"
    -    case ByteType => "tinyint"
    -    case ShortType => "smallint"
    -    case DoubleType => "double"
    -    case LongType => "bigint"
    -    case BinaryType => "binary"
    -    case BooleanType => "boolean"
    -    case DateType => "date"
    -    case d: DecimalType => HiveShim.decimalMetastoreString(d)
    -    case TimestampType => "timestamp"
    -    case NullType => "void"
    -    case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
    -  }
    -}
    -
     private[hive] case class MetastoreRelation
         (databaseName: String, tableName: String, alias: Option[String])
         (val table: TTable, val partitions: Seq[TPartition])
    @@ -551,7 +463,7 @@ private[hive] case class MetastoreRelation
       implicit class SchemaAttribute(f: FieldSchema) {
         def toAttribute = AttributeReference(
           f.getName,
    -      HiveMetastoreTypes.toDataType(f.getType),
    +      sqlContext.ddlParser.parseType(f.getType),
           // Since data can be dumped in randomly with no validation, everything is nullable.
           nullable = true
         )(qualifiers = Seq(alias.getOrElse(tableName)))
    @@ -571,3 +483,27 @@ private[hive] case class MetastoreRelation
       /** An attribute map for determining the ordinal for non-partition columns. */
       val columnOrdinals = AttributeMap(attributes.zipWithIndex)
     }
    +
    +object HiveMetastoreTypes {
    +  def toMetastoreType(dt: DataType): String = dt match {
    +    case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
    +    case StructType(fields) =>
    +      s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
    +    case MapType(keyType, valueType, _) =>
    +      s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
    +    case StringType => "string"
    +    case FloatType => "float"
    +    case IntegerType => "int"
    +    case ByteType => "tinyint"
    +    case ShortType => "smallint"
    +    case DoubleType => "double"
    +    case LongType => "bigint"
    +    case BinaryType => "binary"
    +    case BooleanType => "boolean"
    +    case DateType => "date"
    +    case d: DecimalType => HiveShim.decimalMetastoreString(d)
    +    case TimestampType => "timestamp"
    +    case NullType => "void"
    +    case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
    +  }
    +}
    diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
    index 86535f8dd4f5..041a36f1295e 100644
    --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
    +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
    @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
     import org.scalatest.FunSuite
     
     import org.apache.spark.sql.catalyst.types.StructType
    +import org.apache.spark.sql.sources.DDLParser
     import org.apache.spark.sql.test.ExamplePointUDT
     
     class HiveMetastoreCatalogSuite extends FunSuite {
    @@ -27,7 +28,9 @@ class HiveMetastoreCatalogSuite extends FunSuite {
       test("struct field should accept underscore in sub-column name") {
     val metastr = "struct<a: int, b_1: string, c: string>"
     
    -    val datatype = HiveMetastoreTypes.toDataType(metastr)
    +    val ddlParser = new DDLParser
    +
    +    val datatype = ddlParser.parseType(metastr)
         assert(datatype.isInstanceOf[StructType])
       }
     
    
    From b3e86dc62476abb03b330f86a788aa19a6565317 Mon Sep 17 00:00:00 2001
    From: scwf 
    Date: Sat, 10 Jan 2015 14:08:04 -0800
    Subject: [PATCH 110/215] [SPARK-4861][SQL] Refactory command in spark sql
    
    Follow up for #3712.
This PR finally removes ```CommandStrategy``` and makes all commands follow ```RunnableCommand```, so they can all be planned with ```case r: RunnableCommand => ExecutedCommand(r) :: Nil```.
    
One exception is Hive's ```DescribeCommand```, which is a special case that needs to distinguish between Hive tables and temporary tables, so ```HiveCommandStrategy``` is kept here.
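
For context, a minimal sketch of the pattern everything converges on (the command below is invented purely for illustration; the ```run``` signature is assumed from how ```CreateTableUsing``` is executed in sql/core):

```scala
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.execution.RunnableCommand

// Hypothetical no-op command: side effects would go in run(); no rows are returned.
case class NoopCommand() extends RunnableCommand {
  def run(sqlContext: SQLContext): Seq[Row] = Seq.empty[Row]
}

// The planner then executes any such command generically:
//   case r: RunnableCommand => ExecutedCommand(r) :: Nil
```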
    
    Author: scwf 
    
    Closes #3948 from scwf/followup-SPARK-4861 and squashes the following commits:
    
    6b48e64 [scwf] minor style fix
    2c62e9d [scwf] fix for hive module
    5a7a819 [scwf] Refactory command in spark sql
    ---
     ...ser.scala => AbstractSparkSQLParser.scala} | 69 -------------
     .../sql/catalyst/plans/logical/commands.scala | 48 +--------
     .../org/apache/spark/sql/SQLContext.scala     |  3 +-
     .../org/apache/spark/sql/SparkSQLParser.scala | 97 +++++++++++++++++++
     .../spark/sql/execution/SparkStrategies.scala | 20 +---
     .../apache/spark/sql/execution/commands.scala |  2 +-
     .../spark/sql/hive/thriftserver/Shim12.scala  |  4 +-
     .../spark/sql/hive/thriftserver/Shim13.scala  |  4 +-
     .../apache/spark/sql/hive/HiveContext.scala   |  4 +-
     .../org/apache/spark/sql/hive/HiveQl.scala    | 31 +++++-
     .../spark/sql/hive/HiveStrategies.scala       |  5 +-
     .../org/apache/spark/sql/hive/TestHive.scala  |  3 +-
     .../hive/execution/HiveComparisonTest.scala   |  2 +
     13 files changed, 141 insertions(+), 151 deletions(-)
     rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/{SparkSQLParser.scala => AbstractSparkSQLParser.scala} (62%)
     create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
    
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
    similarity index 62%
    rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
    rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
    index f1a1ca6616a2..93d74adbcc95 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
    @@ -105,72 +105,3 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
         }
       }
     }
    -
    -/**
    - * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
    - * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
    - *
    - * @param fallback A function that parses an input string to a logical plan
    - */
    -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
    -
    -  // A parser for the key-value part of the "SET [key = [value ]]" syntax
    -  private object SetCommandParser extends RegexParsers {
    -    private val key: Parser[String] = "(?m)[^=]+".r
    -
    -    private val value: Parser[String] = "(?m).*$".r
    -
    -    private val pair: Parser[LogicalPlan] =
    -      (key ~ ("=".r ~> value).?).? ^^ {
    -        case None => SetCommand(None)
    -        case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)))
    -      }
    -
    -    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
    -      case Success(plan, _) => plan
    -      case x => sys.error(x.toString)
    -    }
    -  }
    -
    -  protected val AS      = Keyword("AS")
    -  protected val CACHE   = Keyword("CACHE")
    -  protected val LAZY    = Keyword("LAZY")
    -  protected val SET     = Keyword("SET")
    -  protected val TABLE   = Keyword("TABLE")
    -  protected val UNCACHE = Keyword("UNCACHE")
    -
    -  protected implicit def asParser(k: Keyword): Parser[String] =
    -    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
    -
    -  private val reservedWords: Seq[String] =
    -    this
    -      .getClass
    -      .getMethods
    -      .filter(_.getReturnType == classOf[Keyword])
    -      .map(_.invoke(this).asInstanceOf[Keyword].str)
    -
    -  override val lexical = new SqlLexical(reservedWords)
    -
    -  override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others
    -
    -  private lazy val cache: Parser[LogicalPlan] =
    -    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
    -      case isLazy ~ tableName ~ plan =>
    -        CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
    -    }
    -
    -  private lazy val uncache: Parser[LogicalPlan] =
    -    UNCACHE ~ TABLE ~> ident ^^ {
    -      case tableName => UncacheTableCommand(tableName)
    -    }
    -
    -  private lazy val set: Parser[LogicalPlan] =
    -    SET ~> restInput ^^ {
    -      case input => SetCommandParser(input)
    -    }
    -
    -  private lazy val others: Parser[LogicalPlan] =
    -    wholeInput ^^ {
    -      case input => fallback(input)
    -    }
    -}
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
    index 5a1863953eae..45905f8ef98c 100644
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
    @@ -17,8 +17,7 @@
     
     package org.apache.spark.sql.catalyst.plans.logical
     
    -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
    -import org.apache.spark.sql.catalyst.types.StringType
    +import org.apache.spark.sql.catalyst.expressions.Attribute
     
     /**
      * A logical node that represents a non-query command to be executed by the system.  For example,
    @@ -28,48 +27,3 @@ abstract class Command extends LeafNode {
       self: Product =>
       def output: Seq[Attribute] = Seq.empty
     }
    -
    -/**
    - *
    - * Commands of the form "SET [key [= value] ]".
    - */
    -case class SetCommand(kv: Option[(String, Option[String])]) extends Command {
    -  override def output = Seq(
    -    AttributeReference("", StringType, nullable = false)())
    -}
    -
    -/**
    - * Returned by a parser when the users only wants to see what query plan would be executed, without
    - * actually performing the execution.
    - */
    -case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends Command {
    -  override def output =
    -    Seq(AttributeReference("plan", StringType, nullable = false)())
    -}
    -
    -/**
    - * Returned for the "CACHE TABLE tableName [AS SELECT ...]" command.
    - */
    -case class CacheTableCommand(tableName: String, plan: Option[LogicalPlan], isLazy: Boolean)
    -  extends Command
    -
    -/**
    - * Returned for the "UNCACHE TABLE tableName" command.
    - */
    -case class UncacheTableCommand(tableName: String) extends Command
    -
    -/**
    - * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command.
    - * @param table The table to be described.
    - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false.
    - *                   It is effective only when the table is a Hive table.
    - */
    -case class DescribeCommand(
    -    table: LogicalPlan,
    -    isExtended: Boolean) extends Command {
    -  override def output = Seq(
    -    // Column names are based on Hive.
    -    AttributeReference("col_name", StringType, nullable = false)(),
    -    AttributeReference("data_type", StringType, nullable = false)(),
    -    AttributeReference("comment", StringType, nullable = false)())
    -}
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    index 9962937277da..6c575dd727b4 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
    @@ -76,7 +76,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       @transient
       protected[sql] val sqlParser = {
         val fallback = new catalyst.SqlParser
    -    new catalyst.SparkSQLParser(fallback(_))
    +    new SparkSQLParser(fallback(_))
       }
     
       protected[sql] def parseSql(sql: String): LogicalPlan = {
    @@ -329,7 +329,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
     
         def strategies: Seq[Strategy] =
           extraStrategies ++ (
    -      CommandStrategy ::
           DataSourceStrategy ::
           TakeOrdered ::
           HashAggregation ::
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
    new file mode 100644
    index 000000000000..65358b7d4ea8
    --- /dev/null
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
    @@ -0,0 +1,97 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql
    +
    +import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser}
    +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
    +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    +import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand}
    +
    +import scala.util.parsing.combinator.RegexParsers
    +
    +/**
    + * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
    + * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
    + *
    + * @param fallback A function that parses an input string to a logical plan
    + */
    +private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
    +
    +  // A parser for the key-value part of the "SET [key = [value ]]" syntax
    +  private object SetCommandParser extends RegexParsers {
    +    private val key: Parser[String] = "(?m)[^=]+".r
    +
    +    private val value: Parser[String] = "(?m).*$".r
    +
    +    private val output: Seq[Attribute] = Seq(AttributeReference("", StringType, nullable = false)())
    +
    +    private val pair: Parser[LogicalPlan] =
    +      (key ~ ("=".r ~> value).?).? ^^ {
    +        case None => SetCommand(None, output)
    +        case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output)
    +      }
    +
    +    def apply(input: String): LogicalPlan = parseAll(pair, input) match {
    +      case Success(plan, _) => plan
    +      case x => sys.error(x.toString)
    +    }
    +  }
    +
    +  protected val AS      = Keyword("AS")
    +  protected val CACHE   = Keyword("CACHE")
    +  protected val LAZY    = Keyword("LAZY")
    +  protected val SET     = Keyword("SET")
    +  protected val TABLE   = Keyword("TABLE")
    +  protected val UNCACHE = Keyword("UNCACHE")
    +
    +  protected implicit def asParser(k: Keyword): Parser[String] =
    +    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
    +
    +  private val reservedWords: Seq[String] =
    +    this
    +      .getClass
    +      .getMethods
    +      .filter(_.getReturnType == classOf[Keyword])
    +      .map(_.invoke(this).asInstanceOf[Keyword].str)
    +
    +  override val lexical = new SqlLexical(reservedWords)
    +
    +  override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others
    +
    +  private lazy val cache: Parser[LogicalPlan] =
    +    CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
    +      case isLazy ~ tableName ~ plan =>
    +        CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
    +    }
    +
    +  private lazy val uncache: Parser[LogicalPlan] =
    +    UNCACHE ~ TABLE ~> ident ^^ {
    +      case tableName => UncacheTableCommand(tableName)
    +    }
    +
    +  private lazy val set: Parser[LogicalPlan] =
    +    SET ~> restInput ^^ {
    +      case input => SetCommandParser(input)
    +    }
    +
    +  private lazy val others: Parser[LogicalPlan] =
    +    wholeInput ^^ {
    +      case input => fallback(input)
    +    }
    +
    +}
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
    index ce878c137e62..99b6611d3bbc 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
    @@ -259,6 +259,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         def numPartitions = self.numPartitions
     
         def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    +      case r: RunnableCommand => ExecutedCommand(r) :: Nil
    +
           case logical.Distinct(child) =>
             execution.Distinct(partial = false,
               execution.Distinct(partial = true, planLater(child))) :: Nil
    @@ -308,22 +310,4 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           case _ => Nil
         }
       }
    -
    -  case object CommandStrategy extends Strategy {
    -    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    -      case r: RunnableCommand => ExecutedCommand(r) :: Nil
    -      case logical.SetCommand(kv) =>
    -        Seq(ExecutedCommand(execution.SetCommand(kv, plan.output)))
    -      case logical.ExplainCommand(logicalPlan, extended) =>
    -        Seq(ExecutedCommand(
    -          execution.ExplainCommand(logicalPlan, plan.output, extended)))
    -      case logical.CacheTableCommand(tableName, optPlan, isLazy) =>
    -        Seq(ExecutedCommand(
    -          execution.CacheTableCommand(tableName, optPlan, isLazy)))
    -      case logical.UncacheTableCommand(tableName) =>
    -        Seq(ExecutedCommand(
    -          execution.UncacheTableCommand(tableName)))
    -      case _ => Nil
    -    }
    -  }
     }
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
    index b8fa4b019953..0d765c4c92f8 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
    @@ -113,7 +113,7 @@ case class SetCommand(
     @DeveloperApi
     case class ExplainCommand(
         logicalPlan: LogicalPlan,
    -    override val output: Seq[Attribute], extended: Boolean) extends RunnableCommand {
    +    override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand {
     
       // Run through the optimizer to generate the physical plan.
       override def run(sqlContext: SQLContext) = try {
    diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
    index 5550183621fb..80733ea1db93 100644
    --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
    +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
    @@ -33,8 +33,8 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
     import org.apache.hive.service.cli.session.HiveSession
     
     import org.apache.spark.Logging
    -import org.apache.spark.sql.catalyst.plans.logical.SetCommand
     import org.apache.spark.sql.catalyst.types._
    +import org.apache.spark.sql.execution.SetCommand
     import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
     import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
     import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow}
    @@ -190,7 +190,7 @@ private[hive] class SparkExecuteStatementOperation(
           result = hiveContext.sql(statement)
           logDebug(result.queryExecution.toString())
           result.queryExecution.logical match {
    -        case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) =>
    +        case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) =>
               sessionToActivePool(parentSession.getSessionHandle) = value
               logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
             case _ =>
    diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
    index 798a690a2042..19d85140071e 100644
    --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
    +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
    @@ -31,7 +31,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
     import org.apache.hive.service.cli.session.HiveSession
     
     import org.apache.spark.Logging
    -import org.apache.spark.sql.catalyst.plans.logical.SetCommand
    +import org.apache.spark.sql.execution.SetCommand
     import org.apache.spark.sql.catalyst.types._
     import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
     import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}
    @@ -161,7 +161,7 @@ private[hive] class SparkExecuteStatementOperation(
           result = hiveContext.sql(statement)
           logDebug(result.queryExecution.toString())
           result.queryExecution.logical match {
    -        case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) =>
    +        case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) =>
               sessionToActivePool(parentSession.getSessionHandle) = value
               logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
             case _ =>
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
    index 1648fa826b90..02eac43b2103 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
    @@ -38,8 +38,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
     import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry}
     import org.apache.spark.sql.catalyst.plans.logical._
     import org.apache.spark.sql.catalyst.types.DecimalType
    -import org.apache.spark.sql.catalyst.types.decimal.Decimal
    -import org.apache.spark.sql.execution.{SparkPlan, ExecutedCommand, ExtractPythonUdfs, QueryExecutionException}
    +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException}
     import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand}
     import org.apache.spark.sql.sources.DataSourceStrategy
     
    @@ -340,7 +339,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     
         override def strategies: Seq[Strategy] = extraStrategies ++ Seq(
           DataSourceStrategy,
    -      CommandStrategy,
           HiveCommandStrategy(self),
           TakeOrdered,
           ParquetOperations,
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    index c2ab3579c1f9..28de03c38997 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    @@ -24,8 +24,8 @@ import org.apache.hadoop.hive.ql.lib.Node
     import org.apache.hadoop.hive.ql.metadata.Table
     import org.apache.hadoop.hive.ql.parse._
     import org.apache.hadoop.hive.ql.plan.PlanUtils
    +import org.apache.spark.sql.SparkSQLParser
     
    -import org.apache.spark.sql.catalyst.SparkSQLParser
     import org.apache.spark.sql.catalyst.analysis._
     import org.apache.spark.sql.catalyst.expressions._
     import org.apache.spark.sql.catalyst.plans._
    @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical
     import org.apache.spark.sql.catalyst.plans.logical._
     import org.apache.spark.sql.catalyst.types._
     import org.apache.spark.sql.catalyst.types.decimal.Decimal
    +import org.apache.spark.sql.execution.ExplainCommand
     import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable}
     
     /* Implicit conversions */
    @@ -45,6 +46,22 @@ import scala.collection.JavaConversions._
      */
     private[hive] case object NativePlaceholder extends Command
     
    +/**
    + * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command.
    + * @param table The table to be described.
    + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false.
    + *                   It is effective only when the table is a Hive table.
    + */
    +case class DescribeCommand(
    +    table: LogicalPlan,
    +    isExtended: Boolean) extends Command {
    +  override def output = Seq(
    +    // Column names are based on Hive.
    +    AttributeReference("col_name", StringType, nullable = false)(),
    +    AttributeReference("data_type", StringType, nullable = false)(),
    +    AttributeReference("comment", StringType, nullable = false)())
    +}
    +
     /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
     private[hive] object HiveQl {
       protected val nativeCommands = Seq(
    @@ -457,17 +474,23 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
         // Just fake explain for any of the native commands.
         case Token("TOK_EXPLAIN", explainArgs)
           if noExplainCommands.contains(explainArgs.head.getText) =>
    -      ExplainCommand(NoRelation)
    +      ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)()))
         case Token("TOK_EXPLAIN", explainArgs)
           if "TOK_CREATETABLE" == explainArgs.head.getText =>
           val Some(crtTbl) :: _ :: extended :: Nil =
             getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs)
    -      ExplainCommand(nodeToPlan(crtTbl), extended != None)
    +      ExplainCommand(
    +        nodeToPlan(crtTbl),
    +        Seq(AttributeReference("plan", StringType,nullable = false)()),
    +        extended != None)
         case Token("TOK_EXPLAIN", explainArgs) =>
           // Ignore FORMATTED if present.
           val Some(query) :: _ :: extended :: Nil =
             getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
    -      ExplainCommand(nodeToPlan(query), extended != None)
    +      ExplainCommand(
    +        nodeToPlan(query),
    +        Seq(AttributeReference("plan", StringType, nullable = false)()),
    +        extended != None)
     
         case Token("TOK_DESCTABLE", describeArgs) =>
           // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
    index d3f6381b69a4..c439b9ebfe10 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
    @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.planning._
     import org.apache.spark.sql.catalyst.plans._
     import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
     import org.apache.spark.sql.catalyst.types.StringType
    +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
     import org.apache.spark.sql.execution._
     import org.apache.spark.sql.hive
     import org.apache.spark.sql.hive.execution._
    @@ -209,14 +210,14 @@ private[hive] trait HiveStrategies {
     
       case class HiveCommandStrategy(context: HiveContext) extends Strategy {
         def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    -      case describe: logical.DescribeCommand =>
    +      case describe: DescribeCommand =>
             val resolvedTable = context.executePlan(describe.table).analyzed
             resolvedTable match {
               case t: MetastoreRelation =>
                 ExecutedCommand(
                   DescribeHiveTableCommand(t, describe.output, describe.isExtended)) :: Nil
               case o: LogicalPlan =>
    -            ExecutedCommand(DescribeCommand(planLater(o), describe.output)) :: Nil
    +            ExecutedCommand(RunnableDescribeCommand(planLater(o), describe.output)) :: Nil
             }
     
           case _ => Nil
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
    index 8f2311cf83eb..1358a0eccb35 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
    @@ -34,8 +34,9 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe
     import org.apache.spark.{SparkConf, SparkContext}
     import org.apache.spark.util.Utils
     import org.apache.spark.sql.catalyst.analysis._
    -import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan}
    +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
     import org.apache.spark.sql.catalyst.util._
    +import org.apache.spark.sql.execution.CacheTableCommand
     import org.apache.spark.sql.hive._
     import org.apache.spark.sql.SQLConf
     import org.apache.spark.sql.hive.execution.HiveNativeCommand
    diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
    index 4104df8f8e02..f8a957d55d57 100644
    --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
    +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
    @@ -22,6 +22,8 @@ import java.io._
     import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
     
     import org.apache.spark.Logging
    +import org.apache.spark.sql.execution.{SetCommand, ExplainCommand}
    +import org.apache.spark.sql.hive.DescribeCommand
     import org.apache.spark.sql.catalyst.planning.PhysicalOperation
     import org.apache.spark.sql.catalyst.plans.logical._
     import org.apache.spark.sql.catalyst.util._
    
    From 77106df69147aba5eb1784adb84e2b574927c6de Mon Sep 17 00:00:00 2001
    From: Yanbo Liang 
    Date: Sat, 10 Jan 2015 14:16:37 -0800
    Subject: [PATCH 111/215] SPARK-4963 [SQL] Add copy to SQL's Sample operator
    
    https://issues.apache.org/jira/browse/SPARK-4963
SchemaRDD.sample() returns wrong results because GapSamplingIterator operates on mutable rows.
HiveTableScan builds its RDD with SpecificMutableRow, and SchemaRDD.sample() iterates with a GapSamplingIterator.
    
override def next(): T = {
  val r = data.next()
  advance
  r
}
    
GapSamplingIterator.next() returns the current underlying element and assigns it to r.
However, when the underlying iterator yields a mutable row, as HiveTableScan's does, the iterator and r point to the same object.
The advance step then drops some underlying elements and, unexpectedly, also mutates r, so the value returned differs from the initial r.

The most direct fix is to have HiveTableScan return copies of the mutable row, as in my initial commit. That costs HiveTableScan the full benefit of the reusable MutableRow, but it makes the sample operation return correct results.
Going further, GapSamplingIterator.next() could copy elements internally, but that would require every element an RDD can store to implement something like Cloneable, which would be a huge change.
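
The aliasing is easy to reproduce outside Spark. A minimal standalone sketch (plain Scala, none of it Spark API) of an iterator that reuses one mutable buffer, the way HiveTableScan reuses a SpecificMutableRow:

```scala
class ReusingIterator(data: Array[Int]) extends Iterator[Array[Int]] {
  private val reused = new Array[Int](1)
  private var i = 0
  def hasNext: Boolean = i < data.length
  def next(): Array[Int] = { reused(0) = data(i); i += 1; reused }
}

val it = new ReusingIterator(Array(1, 2, 3))
val kept = it.next()   // kept aliases the reused buffer
it.next()              // advancing overwrites it, like GapSamplingIterator's advance
println(kept(0))       // prints 2, not the 1 we thought we sampled

// Copying before advancing (what Sample now does via child.execute().map(_.copy()))
// detaches the element from the shared buffer.
val kept2 = new ReusingIterator(Array(1, 2, 3)).next().clone()
println(kept2(0))      // prints 1
```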
    
    Author: Yanbo Liang 
    
    Closes #3827 from yanbohappy/spark-4963 and squashes the following commits:
    
    0912ca0 [Yanbo Liang] code format keep
    65c4e7c [Yanbo Liang] import file and clear annotation
    55c7c56 [Yanbo Liang] better output of test case
    cea7e2e [Yanbo Liang] SchemaRDD add copy operation before Sample operator
    e840829 [Yanbo Liang] HiveTableScan return mutable row with copy
    ---
     .../apache/spark/sql/execution/basicOperators.scala  |  2 +-
     .../spark/sql/hive/execution/SQLQuerySuite.scala     | 12 ++++++++++++
     2 files changed, 13 insertions(+), 1 deletion(-)
    
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
    index e53723c17656..16ca4be5587c 100644
    --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
    +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
    @@ -70,7 +70,7 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
       override def output = child.output
     
       // TODO: How to pick seed?
    -  override def execute() = child.execute().sample(withReplacement, fraction, seed)
    +  override def execute() = child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
     }
     
     /**
    diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
    index 5d0fb7237011..c1c3683f84ab 100644
    --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
    +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
    @@ -21,6 +21,7 @@ import org.apache.spark.sql.QueryTest
     
     import org.apache.spark.sql.Row
     import org.apache.spark.sql.hive.test.TestHive._
    +import org.apache.spark.util.Utils
     
     case class Nested1(f1: Nested2)
     case class Nested2(f2: Nested3)
    @@ -202,4 +203,15 @@ class SQLQuerySuite extends QueryTest {
         checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"),
           sql("SELECT distinct key FROM src order by key").collect().toSeq)
       }
    +
    +  test("SPARK-4963 SchemaRDD sample on mutable row return wrong result") {
    +    sql("SELECT * FROM src WHERE key % 2 = 0")
    +      .sample(withReplacement = false, fraction = 0.3)
    +      .registerTempTable("sampled")
    +    (1 to 10).foreach { i =>
    +      checkAnswer(
    +        sql("SELECT * FROM sampled WHERE key % 2 = 1"),
    +        Seq.empty[Row])
    +    }
    +  }
     }
    
    From 3684fd21e1ffdc0adaad8ff6b31394b637e866ce Mon Sep 17 00:00:00 2001
    From: Michael Armbrust 
    Date: Sat, 10 Jan 2015 14:25:45 -0800
    Subject: [PATCH 112/215] [SPARK-5187][SQL] Fix caching of tables with HiveUDFs
     in the WHERE clause
    
    Author: Michael Armbrust 
    
    Closes #3987 from marmbrus/hiveUdfCaching and squashes the following commits:
    
    8bca2fa [Michael Armbrust] [SPARK-5187][SQL] Fix caching of tables with HiveUDFs in the WHERE clause
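
My reading of why the case class matters (an inference, not stated in the patch): cached plan lookup relies on plan equality, and a wrapper with only reference equality makes two logically identical HiveUDF plans compare unequal. A minimal illustration:

```scala
// Plain class: reference equality only, so identical wrappers never compare equal.
class PlainWrapper(val functionClassName: String)
// Case class: structural equality, so plans containing equal wrappers stay equal.
case class CaseWrapper(functionClassName: String)

println(new PlainWrapper("org.example.MyUDF") == new PlainWrapper("org.example.MyUDF")) // false
println(CaseWrapper("org.example.MyUDF") == CaseWrapper("org.example.MyUDF"))           // true
```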
    ---
     .../scala/org/apache/spark/sql/hive/CachedTableSuite.scala  | 6 ++++++
     .../src/main/scala/org/apache/spark/sql/hive/Shim12.scala   | 2 +-
     .../src/main/scala/org/apache/spark/sql/hive/Shim13.scala   | 2 +-
     3 files changed, 8 insertions(+), 2 deletions(-)
    
    diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
    index 2060e1f1a7a4..f95a6b43af35 100644
    --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
    +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
    @@ -158,4 +158,10 @@ class CachedTableSuite extends QueryTest {
         uncacheTable("src")
         assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
       }
    +
    +  test("CACHE TABLE with Hive UDF") {
    +    sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1")
    +    assertCached(table("udfTest"))
    +    uncacheTable("udfTest")
    +  }
     }
    diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
    index 2d01a8506751..25fdf5c5f3da 100644
    --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
    +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
    @@ -44,7 +44,7 @@ import scala.language.implicitConversions
     
     import org.apache.spark.sql.catalyst.types.DecimalType
     
    -class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable {
    +case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable {
       // for Serialization
       def this() = this(null)
     
    diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
    index b78c75798e98..e47002cb0b8c 100644
    --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
    +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
    @@ -53,7 +53,7 @@ import scala.language.implicitConversions
      *
      * @param functionClassName UDF class name
      */
    -class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
    +case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
       // for Serialization
       def this() = this(null)
     
    
    From 0ca51cc31d5dd16aa458956e582eec58bbc31711 Mon Sep 17 00:00:00 2001
    From: YanTangZhai 
    Date: Sat, 10 Jan 2015 15:05:23 -0800
    Subject: [PATCH 113/215] [SPARK-4692] [SQL] Support ! boolean logic operator
     like NOT
    
Support the ! boolean logic operator (like NOT) in SQL, as follows:
select * from for_test where !(col1 > col2)
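
A hedged usage sketch (assuming a HiveContext named hiveContext and the for_test table from the example above); after this change both predicates parse to the same Not expression:

```scala
val withNot  = hiveContext.sql("SELECT * FROM for_test WHERE NOT (col1 > col2)")
val withBang = hiveContext.sql("SELECT * FROM for_test WHERE !(col1 > col2)")
```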
    
    Author: YanTangZhai 
    Author: Michael Armbrust 
    
    Closes #3555 from YanTangZhai/SPARK-4692 and squashes the following commits:
    
    1a9f605 [YanTangZhai] Update HiveQuerySuite.scala
    7c03c68 [YanTangZhai] Merge pull request #23 from apache/master
    992046e [YanTangZhai] Update HiveQuerySuite.scala
    ea618f4 [YanTangZhai] Update HiveQuerySuite.scala
    192411d [YanTangZhai] Merge pull request #17 from YanTangZhai/master
    e4c2c0a [YanTangZhai] Merge pull request #15 from apache/master
    1e1ebb4 [YanTangZhai] Update HiveQuerySuite.scala
    efc4210 [YanTangZhai] Update HiveQuerySuite.scala
    bd2c444 [YanTangZhai] Update HiveQuerySuite.scala
    1893956 [YanTangZhai] Merge pull request #14 from marmbrus/pr/3555
    59e4de9 [Michael Armbrust] make hive test
    718afeb [YanTangZhai] Merge pull request #12 from apache/master
    950b21e [YanTangZhai] Update HiveQuerySuite.scala
    74175b4 [YanTangZhai] Update HiveQuerySuite.scala
    92242c7 [YanTangZhai] Update HiveQl.scala
    6e643f8 [YanTangZhai] Merge pull request #11 from apache/master
    e249846 [YanTangZhai] Merge pull request #10 from apache/master
    d26d982 [YanTangZhai] Merge pull request #9 from apache/master
    76d4027 [YanTangZhai] Merge pull request #8 from apache/master
    03b62b0 [YanTangZhai] Merge pull request #7 from apache/master
    8a00106 [YanTangZhai] Merge pull request #6 from apache/master
    cbcba66 [YanTangZhai] Merge pull request #3 from apache/master
    cdef539 [YanTangZhai] Merge pull request #1 from apache/master
    ---
     .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 1 +
     .../golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e  | 1 +
     .../apache/spark/sql/hive/execution/HiveQuerySuite.scala  | 8 ++++++++
     3 files changed, 10 insertions(+)
     create mode 100644 sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e
    
    diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    index 28de03c38997..34622b5f5787 100644
    --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
    @@ -1111,6 +1111,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
         case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right))
         case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
         case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
    +    case Token("!", child :: Nil) => Not(nodeToExpr(child))
     
         /* Case statements */
         case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
    diff --git a/sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e b/sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e
    new file mode 100644
    index 000000000000..d00491fd7e5b
    --- /dev/null
    +++ b/sql/hive/src/test/resources/golden/! operator-0-81d1a187c7f4a6337baf081510a5dc5e	
    @@ -0,0 +1 @@
    +1
    diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
    index fb6da33e88ef..700a45edb11d 100644
    --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
    +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
    @@ -62,6 +62,14 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
           sql("SHOW TABLES")
         }
       }
    +  
    +  createQueryTest("! operator",
    +    """
    +      |SELECT a FROM (
    +      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
    +      |  SELECT 2 AS a FROM src LIMIT 1) table
    +      |WHERE !(a>1)
    +    """.stripMargin)
     
       createQueryTest("constant object inspector for generic udf",
         """SELECT named_struct(
    
    From f0d558b6e6ec0c97280d5844c98fb92c24954cbb Mon Sep 17 00:00:00 2001
    From: CodingCat 
    Date: Sat, 10 Jan 2015 15:35:41 -0800
    Subject: [PATCH 114/215] [SPARK-5181] do not print writing WAL log when WAL is
     disabled
    
    https://issues.apache.org/jira/browse/SPARK-5181
    
Currently, even when the logManager is not created, we still see the log entry
s"Writing to log $record"

A simple fix to make the log more accurate.
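
In sketch form (names simplified; this shows only the shape of the guard, not the actual ReceivedBlockTracker code):

```scala
// Both the debug message and the write now happen only if a log manager was created.
def writeToLog(record: AnyRef, logManagerOption: Option[AnyRef => Unit]): Unit = {
  if (logManagerOption.isDefined) {
    println(s"Writing to log $record")   // stands in for logDebug
    logManagerOption.foreach(write => write(record))
  }
}
```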
    
    Author: CodingCat 
    
    Closes #3985 from CodingCat/SPARK-5181 and squashes the following commits:
    
    0e27dc5 [CodingCat] do not print writing WAL log when WAL is disabled
    ---
     .../spark/streaming/scheduler/ReceivedBlockTracker.scala    | 6 ++++--
     1 file changed, 4 insertions(+), 2 deletions(-)
    
    diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
    index 2ce458cddec1..c3d9d7b6813d 100644
    --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
    +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
    @@ -203,9 +203,11 @@ private[streaming] class ReceivedBlockTracker(
     
       /** Write an update to the tracker to the write ahead log */
       private def writeToLog(record: ReceivedBlockTrackerLogEvent) {
    -    logDebug(s"Writing to log $record")
    -    logManagerOption.foreach { logManager =>
    +    if (isLogManagerEnabled) {
    +      logDebug(s"Writing to log $record")
    +      logManagerOption.foreach { logManager =>
             logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record)))
    +      }
         }
       }
     
    
    From 8a29dc716e3452fdf546852ddc18238018b73891 Mon Sep 17 00:00:00 2001
    From: GuoQiang Li 
    Date: Sat, 10 Jan 2015 15:38:43 -0800
    Subject: [PATCH 115/215] [Minor]Resolve sbt warnings during build
     (MQTTStreamSuite.scala).
    
    cc andrewor14
    
    Author: GuoQiang Li 
    
    Closes #3989 from witgo/MQTTStreamSuite and squashes the following commits:
    
    a6e967e [GuoQiang Li] Resolve sbt warnings during build (MQTTStreamSuite.scala).
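
For context, a small hedged example of the warning being silenced (the language import is what the patch adds; the duration value is illustrative):

```scala
import scala.concurrent.duration._
import scala.language.postfixOps

// Without the language import, postfix syntax like this produces a feature warning under sbt.
val timeout = 10 seconds
```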
    ---
     .../scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala  | 1 +
     1 file changed, 1 insertion(+)
    
    diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
    index 39eb8b183488..30727dfa6443 100644
    --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
    +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
    @@ -20,6 +20,7 @@ package org.apache.spark.streaming.mqtt
     import java.net.{URI, ServerSocket}
     
     import scala.concurrent.duration._
    +import scala.language.postfixOps
     
     import org.apache.activemq.broker.{TransportConnector, BrokerService}
     import org.eclipse.paho.client.mqttv3._
    
    From 92d9a704ce1232bddc570bca13758b11ff9ddb1f Mon Sep 17 00:00:00 2001
    From: wangfei 
    Date: Sat, 10 Jan 2015 17:04:56 -0800
    Subject: [PATCH 116/215] [SPARK-4871][SQL] Show sql statement in spark ui when
     run sql with spark-sql
    
    Author: wangfei 
    
    Closes #3718 from scwf/sparksqlui and squashes the following commits:
    
    e0d6b5d [wangfei] format fix
    383b505 [wangfei] fix conflicts
    4d2038a [wangfei] using setJobDescription
    df79837 [wangfei] fix compile error
    92ce834 [wangfei] show sql statement in spark ui when run sql use spark-sql
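
A hedged sketch of the pattern the shims switch to (the wrapper function is a placeholder; only setJobDescription comes from the patch):

```scala
import org.apache.spark.SparkContext

def runWithDescription(sc: SparkContext, statement: String): Unit = {
  // The statement text shows up as the job description in the web UI
  // for every job spawned while it executes.
  sc.setJobDescription(statement)
  // ... hand the statement to the SQL layer here ...
}
```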
    ---
     core/src/main/scala/org/apache/spark/SparkContext.scala      | 1 -
     .../spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala | 1 +
     .../org/apache/spark/sql/hive/thriftserver/Shim12.scala      | 5 +----
     .../org/apache/spark/sql/hive/thriftserver/Shim13.scala      | 5 +----
     4 files changed, 3 insertions(+), 9 deletions(-)
    
    diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
    index 3bf3acd245d8..ff5d796ee276 100644
    --- a/core/src/main/scala/org/apache/spark/SparkContext.scala
    +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
    @@ -458,7 +458,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
         Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
     
       /** Set a human readable description of the current job. */
    -  @deprecated("use setJobGroup", "0.8.1")
       def setJobDescription(value: String) {
         setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
       }
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala
    index 7a3d76c61c3a..59f3a7576808 100644
    --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala
    +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala
    @@ -53,6 +53,7 @@ private[hive] abstract class AbstractSparkSQLDriver(
       override def run(command: String): CommandProcessorResponse = {
         // TODO unify the error code
         try {
    +      context.sparkContext.setJobDescription(command)
           val execution = context.executePlan(context.sql(command).logicalPlan)
           hiveResponse = execution.stringResult()
           tableSchema = getResultSetSchema(execution)
    diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
    index 80733ea1db93..742acba58d77 100644
    --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
    +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
    @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList, Map => JMap}
     
     import scala.collection.JavaConversions._
     import scala.collection.mutable.{ArrayBuffer, Map => SMap}
    -import scala.math._
     
     import org.apache.hadoop.hive.common.`type`.HiveDecimal
     import org.apache.hadoop.hive.metastore.api.FieldSchema
    @@ -195,9 +194,7 @@ private[hive] class SparkExecuteStatementOperation(
               logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
             case _ =>
           }
    -
    -      val groupId = round(random * 1000000).toString
    -      hiveContext.sparkContext.setJobGroup(groupId, statement)
    +      hiveContext.sparkContext.setJobDescription(statement)
           sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool =>
             hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
           }
    diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
    index 19d85140071e..b82156427a88 100644
    --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
    +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
    @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
     
     import scala.collection.JavaConversions._
     import scala.collection.mutable.{ArrayBuffer, Map => SMap}
    -import scala.math._
     
     import org.apache.hadoop.hive.metastore.api.FieldSchema
     import org.apache.hadoop.security.UserGroupInformation
    @@ -166,9 +165,7 @@ private[hive] class SparkExecuteStatementOperation(
               logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
             case _ =>
           }
    -
    -      val groupId = round(random * 1000000).toString
    -      hiveContext.sparkContext.setJobGroup(groupId, statement)
    +      hiveContext.sparkContext.setJobDescription(statement)
           sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool =>
             hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
           }
    
    From d22a31f5e84e27e27a059f540d08a8a441fc17fa Mon Sep 17 00:00:00 2001
    From: scwf 
    Date: Sat, 10 Jan 2015 17:07:34 -0800
    Subject: [PATCH 117/215] [SPARK-5029][SQL] Enable from follow multiple
     brackets
    
Enable FROM to be followed by multiple bracketed queries:
    ```
    select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1
    ```
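
The grammar change is easiest to see in isolation. A standalone hedged sketch (a toy parser, not Spark's SqlParser) with the same combinator shape, accepting a bare or parenthesized select and chaining UNION ALL:

```scala
import scala.util.parsing.combinator.JavaTokenParsers

object MiniSqlParser extends JavaTokenParsers {
  // Toy "select IDENT" production standing in for Spark's full select rule.
  def select: Parser[String] = "select" ~> ident ^^ (id => s"SELECT $id")
  // The essence of the change: accept the production either bare or wrapped in parentheses.
  def term: Parser[String] = select | "(" ~> select <~ ")"
  def union: Parser[String] = term ~ rep("union all" ~> term) ^^ {
    case first ~ rest => (first +: rest).mkString(" UNION ALL ")
  }
}

// MiniSqlParser.parseAll(MiniSqlParser.union, "(select a) union all select b") now succeeds,
// mirroring how the new start rule accepts parenthesized selects before UNION ALL.
```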
    
    Author: scwf 
    
    Closes #3853 from scwf/from and squashes the following commits:
    
    14f110a [scwf] enable from follow multiple brackets
    ---
     .../apache/spark/sql/catalyst/SqlParser.scala   |  2 +-
     .../org/apache/spark/sql/SQLQuerySuite.scala    | 17 +++++++++++++++++
     2 files changed, 18 insertions(+), 1 deletion(-)
    
    diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
    index fc7b8745590d..5d974df98b69 100755
    --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
    +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
    @@ -125,7 +125,7 @@ class SqlParser extends AbstractSparkSQLParser {
       }
     
       protected lazy val start: Parser[LogicalPlan] =
    -    ( select *
    +    ( (select | ("(" ~> select <~ ")")) *
           ( UNION ~ ALL        ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) }
           | INTERSECT          ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) }
           | EXCEPT             ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)}
    diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
    index add4e218a22e..d9de5686dce4 100644
    --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
    +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
    @@ -272,6 +272,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
           mapData.collect().take(1).toSeq)
       }
     
    +  test("from follow multiple brackets") {
    +    checkAnswer(sql(
    +      "select key from ((select * from testData limit 1) union all (select * from testData limit 1)) x limit 1"),
    +      1
    +    )
    +
    +    checkAnswer(sql(
    +      "select key from (select * from testData) x limit 1"),
    +      1
    +    )
    +
    +    checkAnswer(sql(
    +      "select key from (select * from testData limit 1 union all select * from testData limit 1) x limit 1"),
    +      1
    +    )
    +  }
    +
       test("average") {
         checkAnswer(
           sql("SELECT AVG(a) FROM testData2"),
    
    From 33132609096d7fa45001c6a67724ec60bcaefaa9 Mon Sep 17 00:00:00 2001
    From: "Joseph K. Bradley" 
    Date: Sat, 10 Jan 2015 17:25:39 -0800
    Subject: [PATCH 118/215] [SPARK-5032] [graphx] Remove GraphX MIMA exclude for
     1.3
    
    Since GraphX is no longer alpha as of 1.2, MimaExcludes should not exclude GraphX for 1.3
    
    Here are the individual excludes I had to add + the associated commits:
    
    ```
                // SPARK-4444
                ProblemFilters.exclude[IncompatibleResultTypeProblem](
                  "org.apache.spark.graphx.EdgeRDD.fromEdges"),
                ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.EdgeRDD.filter"),
                ProblemFilters.exclude[IncompatibleResultTypeProblem](
                  "org.apache.spark.graphx.impl.EdgeRDDImpl.filter"),
    ```
    [https://github.com/apache/spark/commit/9ac2bb18ede2e9f73c255fa33445af89aaf8a000]
    
    ```
                // SPARK-3623
                ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.checkpoint")
    ```
    [https://github.com/apache/spark/commit/e895e0cbecbbec1b412ff21321e57826d2d0a982]
    
    ```
                // SPARK-4620
                ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.unpersist"),
    ```
    [https://github.com/apache/spark/commit/8817fc7fe8785d7b11138ca744f22f7e70f1f0a0]
    
    CC: rxin
    
    Author: Joseph K. Bradley 
    
    Closes #3856 from jkbradley/graphx-mima and squashes the following commits:
    
    1eea2f6 [Joseph K. Bradley] moved cleanup to run-tests
    527ccd9 [Joseph K. Bradley] fixed jenkins script to remove ivy2 cache
    802e252 [Joseph K. Bradley] Removed GraphX MIMA excludes and added line to clear spark from .m2 dir before Jenkins tests.  This may not work yet...
    30f8bb4 [Joseph K. Bradley] added individual mima excludes for graphx
    a3fea42 [Joseph K. Bradley] removed graphx mima exclude for 1.3
    ---
     dev/run-tests              | 4 +++-
     project/MimaExcludes.scala | 1 -
     2 files changed, 3 insertions(+), 2 deletions(-)
    
    diff --git a/dev/run-tests b/dev/run-tests
    index 20603fc08923..2257a566bb1b 100755
    --- a/dev/run-tests
    +++ b/dev/run-tests
    @@ -21,8 +21,10 @@
     FWDIR="$(cd "`dirname $0`"/..; pwd)"
     cd "$FWDIR"
     
    -# Remove work directory
    +# Clean up work directory and caches
     rm -rf ./work
    +rm -rf ~/.ivy2/local/org.apache.spark
    +rm -rf ~/.ivy2/cache/org.apache.spark
     
     source "$FWDIR/dev/run-tests-codes.sh"
     
    diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
    index 31d4c317ae56..51e8bd4cf641 100644
    --- a/project/MimaExcludes.scala
    +++ b/project/MimaExcludes.scala
    @@ -36,7 +36,6 @@ object MimaExcludes {
             case v if v.startsWith("1.3") =>
               Seq(
                 MimaBuild.excludeSparkPackage("deploy"),
    -            MimaBuild.excludeSparkPackage("graphx"),
                 // These are needed if checking against the sbt build, since they are part of
                 // the maven-generated artifacts in the 1.2 build.
                 MimaBuild.excludeSparkPackage("unused"),
    
    From 1656aae2b4e8b026f8cfe782519f72d32ed2b291 Mon Sep 17 00:00:00 2001
    From: lewuathe 
    Date: Sun, 11 Jan 2015 13:50:42 -0800
    Subject: [PATCH 119/215] [SPARK-5073] spark.storage.memoryMapThreshold have
     two default value
    
Because the typical OS page size is about 4 KB, the code default of spark.storage.memoryMapThreshold was 2 * 4096, which did not match the documented default; this change unifies the default at 2 MB.
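
For reference, a hedged sketch of how the unified default is read (the key and the 2 MB value come from this patch; the surrounding code is illustrative):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
// Blocks at or above this many bytes get memory-mapped by DiskStore; smaller ones are read directly.
val minMemoryMapBytes = conf.getLong("spark.storage.memoryMapThreshold", 2 * 1024L * 1024L)
```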
    
    Author: lewuathe 
    
    Closes #3900 from Lewuathe/integrate-memoryMapThreshold and squashes the following commits:
    
    e417acd [lewuathe] [SPARK-5073] Update docs/configuration
    834aba4 [lewuathe] [SPARK-5073] Fix style
    adcea33 [lewuathe] [SPARK-5073] Integrate memory map threshold to 2MB
    fcce2e5 [lewuathe] [SPARK-5073] spark.storage.memoryMapThreshold have two default value
    ---
     core/src/main/scala/org/apache/spark/storage/DiskStore.scala | 3 ++-
     docs/configuration.md                                        | 2 +-
     2 files changed, 3 insertions(+), 2 deletions(-)
    
    diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
    index 8dadf6794039..61ef5ff16879 100644
    --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
    +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
    @@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
     private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
       extends BlockStore(blockManager) with Logging {
     
    -  val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L)
    +  val minMemoryMapBytes = blockManager.conf.getLong(
    +    "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L)
     
       override def getSize(blockId: BlockId): Long = {
         diskManager.getFile(blockId.name).length
    diff --git a/docs/configuration.md b/docs/configuration.md
    index 2add48569bec..f292bfbb7dcd 100644
    --- a/docs/configuration.md
    +++ b/docs/configuration.md
@@ -678,7 +678,7 @@ Apart from these, the following properties are also available, and may be useful
     [table row markup stripped in the source; this hunk updates the spark.storage.memoryMapThreshold entry of the configuration table]
    @@ -117,6 +116,14 @@ of the most common options to set are:
         (e.g. 512m, 2g).
       
     [added table rows stripped in the source; they appear to document spark.driver.cores]
    diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
    index 4f273098c5db..68ab127bcf08 100644
    --- a/docs/running-on-yarn.md
    +++ b/docs/running-on-yarn.md
    @@ -29,6 +29,23 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
         In cluster mode, use spark.driver.memory instead.
       
     [added table rows stripped in the source; they appear to document spark.yarn.am.cores and the driver-cores settings]
    diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    index 032106371cd6..d4eeccf64275 100644
    --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
    @@ -127,6 +127,7 @@ private[spark] class Client(
         }
         val capability = Records.newRecord(classOf[Resource])
         capability.setMemory(args.amMemory + amMemoryOverhead)
    +    capability.setVirtualCores(args.amCores)
         appContext.setResource(capability)
         appContext
       }
    diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
    index 461a9ccd3c21..79bead77ba6e 100644
    --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
    +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
    @@ -36,14 +36,18 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
       var numExecutors = DEFAULT_NUMBER_EXECUTORS
       var amQueue = sparkConf.get("spark.yarn.queue", "default")
       var amMemory: Int = 512 // MB
    +  var amCores: Int = 1
       var appName: String = "Spark"
       var priority = 0
       def isClusterMode: Boolean = userClass != null
     
       private var driverMemory: Int = 512 // MB
    +  private var driverCores: Int = 1
       private val driverMemOverheadKey = "spark.yarn.driver.memoryOverhead"
       private val amMemKey = "spark.yarn.am.memory"
       private val amMemOverheadKey = "spark.yarn.am.memoryOverhead"
    +  private val driverCoresKey = "spark.driver.cores"
    +  private val amCoresKey = "spark.yarn.am.cores"
       private val isDynamicAllocationEnabled =
         sparkConf.getBoolean("spark.dynamicAllocation.enabled", false)
     
    @@ -92,19 +96,25 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
             "You must specify at least 1 executor!\n" + getUsageMessage())
         }
         if (isClusterMode) {
    -      for (key <- Seq(amMemKey, amMemOverheadKey)) {
    +      for (key <- Seq(amMemKey, amMemOverheadKey, amCoresKey)) {
             if (sparkConf.contains(key)) {
               println(s"$key is set but does not apply in cluster mode.")
             }
           }
           amMemory = driverMemory
    +      amCores = driverCores
         } else {
    -      if (sparkConf.contains(driverMemOverheadKey)) {
    -        println(s"$driverMemOverheadKey is set but does not apply in client mode.")
    +      for (key <- Seq(driverMemOverheadKey, driverCoresKey)) {
    +        if (sparkConf.contains(key)) {
    +          println(s"$key is set but does not apply in client mode.")
    +        }
           }
           sparkConf.getOption(amMemKey)
             .map(Utils.memoryStringToMb)
             .foreach { mem => amMemory = mem }
    +      sparkConf.getOption(amCoresKey)
    +        .map(_.toInt)
    +        .foreach { cores => amCores = cores }
         }
       }
     
    @@ -140,6 +150,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
               driverMemory = value
               args = tail
     
    +        case ("--driver-cores") :: IntParam(value) :: tail =>
    +          driverCores = value
    +          args = tail
    +
             case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail =>
               if (args(0) == "--num-workers") {
                 println("--num-workers is deprecated. Use --num-executors instead.")
    @@ -198,7 +212,8 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
     
       private def getUsageMessage(unknownParam: List[String] = null): String = {
         val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else ""
    -    message + """
    +    message +
    +      """
           |Usage: org.apache.spark.deploy.yarn.Client [options]
           |Options:
           |  --jar JAR_PATH           Path to your application's JAR file (required in yarn-cluster
    @@ -209,6 +224,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
           |  --num-executors NUM      Number of executors to start (Default: 2)
           |  --executor-cores NUM     Number of cores for the executors (Default: 1).
           |  --driver-memory MEM      Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)
    +      |  --driver-cores NUM       Number of cores used by the driver (Default: 1).
           |  --executor-memory MEM    Memory per executor (e.g. 1000M, 2G) (Default: 1G)
           |  --name NAME              The name of your application (Default: Spark)
           |  --queue QUEUE            The hadoop queue to use for allocation requests (Default:
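
For reference, a hedged sketch of how these new core settings could be supplied from application code; the property names come from the diff above, while the object name and values are arbitrary examples.

```scala
import org.apache.spark.SparkConf

object YarnCoreSettingsExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("yarn-cores-example")
      // In yarn-cluster mode the ApplicationMaster runs the driver, so its core
      // count follows the driver setting (--driver-cores / spark.driver.cores).
      .set("spark.driver.cores", "4")
      // In yarn-client mode the ApplicationMaster is separate from the driver,
      // so its core count is controlled by spark.yarn.am.cores.
      .set("spark.yarn.am.cores", "2")

    conf.getAll.foreach { case (k, v) => println(s"$k=$v") }
  }
}
```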
    
    From e200ac8e53a533d64a79c18561b557ea445f1cc9 Mon Sep 17 00:00:00 2001
    From: Ye Xianjin 
    Date: Fri, 16 Jan 2015 09:20:53 -0800
    Subject: [PATCH 157/215] [SPARK-5201][CORE] deal with int overflow in the
     ParallelCollectionRDD.slice method
    
There is an int overflow in the ParallelCollectionRDD.slice method, originally reported by SaintBacchus.
    ```
    sc.makeRDD(1 to (Int.MaxValue)).count       // result = 0
    sc.makeRDD(1 to (Int.MaxValue - 1)).count   // result = 2147483646 = Int.MaxValue - 1
    sc.makeRDD(1 until (Int.MaxValue)).count    // result = 2147483646 = Int.MaxValue - 1
    ```
    see https://github.com/apache/spark/pull/2874 for more details.
This PR tries to fix the overflow. However, there's another issue it doesn't address:
    ```
    val largeRange = Int.MinValue to Int.MaxValue
    largeRange.length // throws java.lang.IllegalArgumentException: -2147483648 to 2147483647 by 1: seqs cannot contain more than Int.MaxValue elements.
    ```
    
So the range we feed to sc.makeRDD cannot contain more than Int.MaxValue elements; this is a limitation of Scala. We may eventually want to support that kind of range, but that fix is beyond this PR.
    
srowen andrewor14, would you mind taking a look at this PR?
    
    Author: Ye Xianjin 
    
    Closes #4002 from advancedxy/SPARk-5201 and squashes the following commits:
    
    96265a1 [Ye Xianjin] Update slice method comment and some responding docs.
    e143d7a [Ye Xianjin] Update inclusive range check for splitting inclusive range.
    b3f5577 [Ye Xianjin] We can include the last element in the last slice in general for inclusive range, hence eliminate the need to check Int.MaxValue or Int.MinValue.
    7d39b9e [Ye Xianjin] Convert the two cases pattern matching to one case.
    651c959 [Ye Xianjin] rename sign to needsInclusiveRange. add some comments
    196f8a8 [Ye Xianjin] Add test cases for ranges end with Int.MaxValue or Int.MinValue
    e66e60a [Ye Xianjin] Deal with inclusive and exclusive ranges in one case. If the range is inclusive and the end of the range is (Int.MaxValue or Int.MinValue), we should use inclusive range instead of exclusive
    ---
     .../scala/org/apache/spark/SparkContext.scala |  7 +++---
     .../spark/rdd/ParallelCollectionRDD.scala     | 21 +++++++---------
     .../rdd/ParallelCollectionSplitSuite.scala    | 25 +++++++++++++++++++
     3 files changed, 37 insertions(+), 16 deletions(-)
    
    diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
    index ff5d796ee276..6a354ed4d148 100644
    --- a/core/src/main/scala/org/apache/spark/SparkContext.scala
    +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
    @@ -520,10 +520,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     
       /** Distribute a local Scala collection to form an RDD.
        *
    -   * @note Parallelize acts lazily. If `seq` is a mutable collection and is
    -   * altered after the call to parallelize and before the first action on the
    -   * RDD, the resultant RDD will reflect the modified collection. Pass a copy of
    -   * the argument to avoid this.
    +   * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
    +   * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
    +   * modified collection. Pass a copy of the argument to avoid this.
        */
       def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
         new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
    diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
    index 87b22de6ae69..f12d0cffaba3 100644
    --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
    +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
    @@ -111,7 +111,8 @@ private object ParallelCollectionRDD {
       /**
        * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
        * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
    -   * it efficient to run Spark over RDDs representing large sets of numbers.
    +   * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
    +   * is an inclusive Range, we use inclusive range for the last slice.
        */
       def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
         if (numSlices < 1) {
    @@ -127,19 +128,15 @@ private object ParallelCollectionRDD {
           })
         }
         seq match {
    -      case r: Range.Inclusive => {
    -        val sign = if (r.step < 0) {
    -          -1
    -        } else {
    -          1
    -        }
    -        slice(new Range(
    -          r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
    -      }
           case r: Range => {
    -        positions(r.length, numSlices).map({
    -          case (start, end) =>
    +        positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
    +          // If the range is inclusive, use inclusive range for the last slice
    +          if (r.isInclusive && index == numSlices - 1) {
    +            new Range.Inclusive(r.start + start * r.step, r.end, r.step)
    +          }
    +          else {
                 new Range(r.start + start * r.step, r.start + end * r.step, r.step)
    +          }
             }).toSeq.asInstanceOf[Seq[Seq[T]]]
           }
           case nr: NumericRange[_] => {
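
As a standalone illustration of the fixed slicing logic (assuming the Scala 2.10-era Range constructors used in the patched file; `sliceRange` and `positions` are local helper names, not Spark APIs): slice boundaries are computed on indices, and the final slice of an inclusive Range is itself built as an inclusive Range, so an end value such as Int.MaxValue is neither dropped nor overflowed past.

```scala
object RangeSliceSketch {
  // e.g. sliceRange(1 to Int.MaxValue, 3).map(_.size.toLong).sum == Int.MaxValue
  def sliceRange(r: Range, numSlices: Int): Seq[Range] = {
    // Start/end indices of each slice, computed in Long to avoid overflow.
    def positions(length: Long, n: Int): Iterator[(Int, Int)] =
      (0 until n).iterator.map { i =>
        (((i * length) / n).toInt, (((i + 1) * length) / n).toInt)
      }

    positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
      if (r.isInclusive && index == numSlices - 1) {
        // Last slice of an inclusive range keeps r.end inclusively.
        new Range.Inclusive(r.start + start * r.step, r.end, r.step)
      } else {
        new Range(r.start + start * r.step, r.start + end * r.step, r.step)
      }
    }.toSeq
  }
}
```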
    diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
    index 1b112f1a41ca..cd193ae4f523 100644
    --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
    +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
    @@ -76,6 +76,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
         assert(slices(0).mkString(",") === (0 to 32).mkString(","))
         assert(slices(1).mkString(",") === (33 to 66).mkString(","))
         assert(slices(2).mkString(",") === (67 to 100).mkString(","))
    +    assert(slices(2).isInstanceOf[Range.Inclusive])
       }
     
       test("empty data") {
    @@ -227,4 +228,28 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
         assert(slices.map(_.size).reduceLeft(_+_) === 100)
         assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
       }
    +
    +  test("inclusive ranges with Int.MaxValue and Int.MinValue") {
    +    val data1 = 1 to Int.MaxValue
    +    val slices1 = ParallelCollectionRDD.slice(data1, 3)
    +    assert(slices1.size === 3)
    +    assert(slices1.map(_.size).sum === Int.MaxValue)
    +    assert(slices1(2).isInstanceOf[Range.Inclusive])
    +    val data2 = -2 to Int.MinValue by -1
    +    val slices2 = ParallelCollectionRDD.slice(data2, 3)
    +    assert(slices2.size == 3)
    +    assert(slices2.map(_.size).sum === Int.MaxValue)
    +    assert(slices2(2).isInstanceOf[Range.Inclusive])
    +  }
    +
    +  test("empty ranges with Int.MaxValue and Int.MinValue") {
    +    val data1 = Int.MaxValue until Int.MaxValue
    +    val slices1 = ParallelCollectionRDD.slice(data1, 5)
    +    assert(slices1.size === 5)
    +    for (i <- 0 until 5) assert(slices1(i).size === 0)
     +    val data2 = Int.MinValue until Int.MinValue
    +    val slices2 = ParallelCollectionRDD.slice(data2, 5)
    +    assert(slices2.size === 5)
    +    for (i <- 0 until 5) assert(slices2(i).size === 0)
    +  }
     }
    
    From f6b852aade7668c99f37c69f606c64763cb265d2 Mon Sep 17 00:00:00 2001
    From: Sean Owen 
    Date: Fri, 16 Jan 2015 09:28:44 -0800
    Subject: [PATCH 158/215] [DOCS] Fix typo in return type of cogroup
    
    This fixes a simple typo in the cogroup docs noted in http://mail-archives.apache.org/mod_mbox/spark-user/201501.mbox/%3CCAMAsSdJ8_24evMAMg7fOZCQjwimisbYWa9v8BN6Rc3JCauja6wmail.gmail.com%3E
    
    I didn't bother with a JIRA
    
    Author: Sean Owen 
    
    Closes #4072 from srowen/CogroupDocFix and squashes the following commits:
    
    43c850b [Sean Owen] Fix typo in return type of cogroup
    ---
     docs/programming-guide.md | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/docs/programming-guide.md b/docs/programming-guide.md
    index 5e0d5c15d706..0211bbabc113 100644
    --- a/docs/programming-guide.md
    +++ b/docs/programming-guide.md
    @@ -913,7 +913,7 @@ for details.
     
     [HTML table markup lost in extraction: this hunk corrects the documented return type of cogroup to (K, (Iterable<V>, Iterable<W>)) in the transformations table.]
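
To make the corrected signature concrete, a small self-contained example (local mode, illustrative data and names) showing that cogroup produces one (Iterable[V], Iterable[W]) pair per key:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

object CogroupReturnType {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("cogroup-return-type"))
    val left  = sc.parallelize(Seq("a" -> 1, "a" -> 2, "b" -> 3))
    val right = sc.parallelize(Seq("a" -> "x", "c" -> "y"))

    // One element per distinct key, pairing all of that key's values from each side.
    val grouped: RDD[(String, (Iterable[Int], Iterable[String]))] = left.cogroup(right)

    grouped.collect().sortBy(_._1).foreach { case (k, (vs, ws)) =>
      println(s"$k -> (${vs.mkString(",")} | ${ws.mkString(",")})")
    }
    sc.stop()
  }
}
```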
    
    From e8422c521bc76bc4b03c337605f136403ea9f64a Mon Sep 17 00:00:00 2001
    From: Kousuke Saruta 
    Date: Fri, 16 Jan 2015 10:05:11 -0800
    Subject: [PATCH 159/215] [SPARK-5231][WebUI] History Server shows wrong job
     submission time.
    
History Server doesn't show the correct job submission time.
    It's because `JobProgressListener` updates job submission time every time `onJobStart` method is invoked from `ReplayListenerBus`.
    
    Author: Kousuke Saruta 
    
    Closes #4029 from sarutak/SPARK-5231 and squashes the following commits:
    
    0af9e22 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5231
da8bd14 [Kousuke Saruta] Made submissionTime in SparkListenerJobStart and completionTime in SparkListenerJobEnd regular Longs
    0412a6a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5231
    26b9b99 [Kousuke Saruta] Fixed the test cases
2d47bd3 [Kousuke Saruta] Fixed to record job submission time and completion time correctly
    ---
     .../apache/spark/scheduler/DAGScheduler.scala | 16 ++++++----
     .../spark/scheduler/SparkListener.scala       |  7 ++++-
     .../apache/spark/ui/jobs/AllJobsPage.scala    | 15 +++++----
     .../spark/ui/jobs/JobProgressListener.scala   | 10 +++---
     .../org/apache/spark/ui/jobs/UIData.scala     |  6 ++--
     .../org/apache/spark/util/JsonProtocol.scala  | 11 +++++--
     .../spark/scheduler/SparkListenerSuite.scala  | 10 +++---
     .../ui/jobs/JobProgressListenerSuite.scala    |  6 ++--
     .../apache/spark/util/JsonProtocolSuite.scala | 31 ++++++++++++++++---
     9 files changed, 77 insertions(+), 35 deletions(-)
    
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    index 8cb15918baa8..3bca59e0646d 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
    @@ -661,7 +661,7 @@ class DAGScheduler(
           // completion events or stage abort
           stageIdToStage -= s.id
           jobIdToStageIds -= job.jobId
    -      listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
    +      listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult))
         }
       }
     
    @@ -710,7 +710,7 @@ class DAGScheduler(
             stage.latestInfo.stageFailed(stageFailedMessage)
             listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
           }
    -      listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
    +      listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
         }
       }
     
    @@ -749,9 +749,11 @@ class DAGScheduler(
           logInfo("Missing parents: " + getMissingParentStages(finalStage))
           val shouldRunLocally =
             localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
    +      val jobSubmissionTime = clock.getTime()
           if (shouldRunLocally) {
             // Compute very short actions like first() or take() with no parent stages locally.
    -        listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties))
    +        listenerBus.post(
    +          SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
             runLocally(job)
           } else {
             jobIdToActiveJob(jobId) = job
    @@ -759,7 +761,8 @@ class DAGScheduler(
             finalStage.resultOfJob = Some(job)
             val stageIds = jobIdToStageIds(jobId).toArray
             val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
    -        listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties))
    +        listenerBus.post(
    +          SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
             submitStage(finalStage)
           }
         }
    @@ -965,7 +968,8 @@ class DAGScheduler(
                       if (job.numFinished == job.numPartitions) {
                         markStageAsFinished(stage)
                         cleanupStateForJobAndIndependentStages(job)
    -                    listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
    +                    listenerBus.post(
    +                      SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded))
                       }
     
                       // taskSucceeded runs some user code that might throw an exception. Make sure
    @@ -1234,7 +1238,7 @@ class DAGScheduler(
         if (ableToCancelStages) {
           job.listener.jobFailed(error)
           cleanupStateForJobAndIndependentStages(job)
    -      listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
    +      listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
         }
       }
     
    diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
    index 4840d8bd2d2f..e5d1eb767e10 100644
    --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
    +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
    @@ -59,6 +59,7 @@ case class SparkListenerTaskEnd(
     @DeveloperApi
     case class SparkListenerJobStart(
         jobId: Int,
    +    time: Long,
         stageInfos: Seq[StageInfo],
         properties: Properties = null)
       extends SparkListenerEvent {
    @@ -68,7 +69,11 @@ case class SparkListenerJobStart(
     }
     
     @DeveloperApi
    -case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
    +case class SparkListenerJobEnd(
    +    jobId: Int,
    +    time: Long,
    +    jobResult: JobResult)
    +  extends SparkListenerEvent
     
     @DeveloperApi
     case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]])
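
As a usage sketch, a hypothetical listener (not part of this patch; the class name is invented) that consumes the new `time` fields on SparkListenerJobStart and SparkListenerJobEnd to report per-job wall-clock durations:

```scala
import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart}

class JobDurationListener extends SparkListener {
  // Listener events are delivered from a single bus thread, so a plain map suffices.
  private val submitted = mutable.Map.empty[Int, Long]

  override def onJobStart(jobStart: SparkListenerJobStart): Unit =
    submitted(jobStart.jobId) = jobStart.time

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
    submitted.remove(jobEnd.jobId).foreach { start =>
      // jobEnd.time - start is the wall-clock duration the UI can now show
      // consistently, even when events are replayed by the History Server.
      println(s"Job ${jobEnd.jobId} finished as ${jobEnd.jobResult} in ${jobEnd.time - start} ms")
    }
}
// Registered from driver code with e.g. sc.addSparkListener(new JobDurationListener).
```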
    diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
    index 1d1c70187844..81212708ba52 100644
    --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
    +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
    @@ -21,7 +21,6 @@ import scala.xml.{Node, NodeSeq}
     
     import javax.servlet.http.HttpServletRequest
     
    -import org.apache.spark.JobExecutionStatus
     import org.apache.spark.ui.{WebUIPage, UIUtils}
     import org.apache.spark.ui.jobs.UIData.JobUIData
     
    @@ -51,13 +50,13 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
           val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)")
           val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("")
           val duration: Option[Long] = {
    -        job.startTime.map { start =>
    -          val end = job.endTime.getOrElse(System.currentTimeMillis())
    +        job.submissionTime.map { start =>
    +          val end = job.completionTime.getOrElse(System.currentTimeMillis())
               end - start
             }
           }
           val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
    -      val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown")
    +      val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown")
           val detailUrl =
             "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId)
           
    @@ -68,7 +67,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
               
    {lastStageDescription}
    {lastStageName} - @@ -101,11 +100,11 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val now = System.currentTimeMillis val activeJobsTable = - jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + jobsTable(activeJobs.sortBy(_.submissionTime.getOrElse(-1L)).reverse) val completedJobsTable = - jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + jobsTable(completedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse) val failedJobsTable = - jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + jobsTable(failedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse) val shouldShowActiveJobs = activeJobs.nonEmpty val shouldShowCompletedJobs = completedJobs.nonEmpty diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 72935beb3a34..b0d3bed1300b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -153,14 +153,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val jobData: JobUIData = new JobUIData( jobId = jobStart.jobId, - startTime = Some(System.currentTimeMillis), - endTime = None, + submissionTime = Option(jobStart.time).filter(_ >= 0), stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result - // stages's transitive stage dependencies, but some of these stages might be skipped if their + // stages' transitive stage dependencies, but some of these stages might be skipped if their // output is available from earlier runs. // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. jobData.numTasks = { @@ -186,7 +185,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) } - jobData.endTime = Some(System.currentTimeMillis()) + jobData.completionTime = Option(jobEnd.time).filter(_ >= 0) + jobEnd.jobResult match { case JobSucceeded => completedJobs += jobData @@ -309,7 +309,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task - // compeletion event is for. Let's just drop it here. This means we might have some speculation + // completion event is for. Let's just drop it here. This means we might have some speculation // tasks on the web ui that's never marked as complete. 
if (info != null && taskEnd.stageAttemptId != -1) { val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 48fd7caa1a1e..01f7e23212c3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -40,15 +40,15 @@ private[jobs] object UIData { class JobUIData( var jobId: Int = -1, - var startTime: Option[Long] = None, - var endTime: Option[Long] = None, + var submissionTime: Option[Long] = None, + var completionTime: Option[Long] = None, var stageIds: Seq[Int] = Seq.empty, var jobGroup: Option[String] = None, var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, /* Tasks */ // `numTasks` is a potential underestimate of the true number of tasks that this job will run. // This may be an underestimate because the job start event references all of the result - // stages's transitive stage dependencies, but some of these stages might be skipped if their + // stages' transitive stage dependencies, but some of these stages might be skipped if their // output is available from earlier runs. // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. var numTasks: Int = 0, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ee3756c226fe..76709a230f83 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -32,6 +32,7 @@ import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ import org.apache.spark._ +import org.apache.hadoop.hdfs.web.JsonUtil /** * Serializes SparkListener events to/from JSON. 
This protocol provides strong backwards- @@ -141,6 +142,7 @@ private[spark] object JsonProtocol { val properties = propertiesToJson(jobStart.properties) ("Event" -> Utils.getFormattedClassName(jobStart)) ~ ("Job ID" -> jobStart.jobId) ~ + ("Submission Time" -> jobStart.time) ~ ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 ("Stage IDs" -> jobStart.stageIds) ~ ("Properties" -> properties) @@ -150,6 +152,7 @@ private[spark] object JsonProtocol { val jobResult = jobResultToJson(jobEnd.jobResult) ("Event" -> Utils.getFormattedClassName(jobEnd)) ~ ("Job ID" -> jobEnd.jobId) ~ + ("Completion Time" -> jobEnd.time) ~ ("Job Result" -> jobResult) } @@ -492,6 +495,8 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] + val submissionTime = + Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 @@ -499,13 +504,15 @@ private[spark] object JsonProtocol { .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) } - SparkListenerJobStart(jobId, stageInfos, properties) + SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) } def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] + val completionTime = + Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") - SparkListenerJobEnd(jobId, jobResult) + SparkListenerJobEnd(jobId, completionTime, jobResult) } def environmentUpdateFromJson(json: JValue): SparkListenerEnvironmentUpdate = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 24f41bf8cccd..0fb1bdd30d97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -34,6 +34,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers /** Length of time to wait while draining listener events. 
*/ val WAIT_TIMEOUT_MILLIS = 10000 + val jobCompletionTime = 1421191296660L + before { sc = new SparkContext("local", "SparkListenerSuite") } @@ -44,7 +46,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(counter) // Listener bus hasn't started yet, so posting events should not increment counter - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -54,7 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // After listener bus has stopped, posting events should not increment counter bus.stop() - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) // Listener bus must not be started twice @@ -99,7 +101,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(blockingListener) bus.start() - bus.post(SparkListenerJobEnd(0, JobSucceeded)) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() // Listener should be blocked after start @@ -345,7 +347,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.start() // Post events to all listeners, and wait until the queue is drained - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) // The exception should be caught, and the event should be propagated to other listeners diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index f865d8ca04d1..c9417ea1ed8f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + val jobSubmissionTime = 1421191042750L + val jobCompletionTime = 1421191296660L private def createStageStartEvent(stageId: Int) = { val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") @@ -46,12 +48,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, stageInfos) + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded - SparkListenerJobEnd(jobId, result) + SparkListenerJobEnd(jobId, jobCompletionTime, result) } private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 71dfed128985..db400b416291 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -34,6 +34,9 
@@ import org.apache.spark.storage._ class JsonProtocolSuite extends FunSuite { + val jobSubmissionTime = 1421191042750L + val jobCompletionTime = 1421191296660L + test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) @@ -54,9 +57,9 @@ class JsonProtocolSuite extends FunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) - SparkListenerJobStart(10, stageInfos, properties) + SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) } - val jobEnd = SparkListenerJobEnd(20, JobSucceeded) + val jobEnd = SparkListenerJobEnd(20, jobCompletionTime, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( "JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")), "Spark Properties" -> Seq(("Job throughput", "80000 jobs/s, regardless of job type")), @@ -247,13 +250,31 @@ class JsonProtocolSuite extends FunSuite { val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) - val jobStart = SparkListenerJobStart(10, stageInfos, properties) + val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) val expectedJobStart = - SparkListenerJobStart(10, dummyStageInfos, properties) + SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) } + test("SparkListenerJobStart and SparkListenerJobEnd backward compatibility") { + // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. + // Also, SparkListenerJobEnd did not have a "Completion Time" property. 
+ val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) + val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) + .removeField({ _._1 == "Submission Time"}) + val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties) + assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent)) + + val jobEnd = SparkListenerJobEnd(11, jobCompletionTime, JobSucceeded) + val oldEndEvent = JsonProtocol.jobEndToJson(jobEnd) + .removeField({ _._1 == "Completion Time"}) + val expectedJobEnd = SparkListenerJobEnd(11, -1, JobSucceeded) + assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -1075,6 +1096,7 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobStart", | "Job ID": 10, + | "Submission Time": 1421191042750, | "Stage Infos": [ | { | "Stage ID": 1, @@ -1349,6 +1371,7 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobEnd", | "Job ID": 20, + | "Completion Time": 1421191296660, | "Job Result": { | "Result": "JobSucceeded" | } From ecf943d35342191a1362d80bb26f2a098c152f27 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 16 Jan 2015 12:19:08 -0800 Subject: [PATCH 160/215] [WebUI] Fix collapse of WebUI layout When we decrease the width of browsers, the header of WebUI wraps and collapses like as following image. ![2015-01-11 19 49 37](https://cloud.githubusercontent.com/assets/4736016/5698887/b0b9aeee-99cd-11e4-9020-08f3f0014de0.png) Author: Kousuke Saruta Closes #3995 from sarutak/fixed-collapse-webui-layout and squashes the following commits: 3e60b5b [Kousuke Saruta] Modified line-height property in webui.css 7bfb5fb [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into fixed-collapse-webui-layout 5d83e18 [Kousuke Saruta] Fixed collapse of WebUI layout --- .../main/resources/org/apache/spark/ui/static/webui.css | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 5751964b792c..f02b035a980b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -19,6 +19,7 @@ height: 50px; font-size: 15px; margin-bottom: 15px; + min-width: 1200px } .navbar .navbar-inner { @@ -39,12 +40,12 @@ .navbar .nav > li a { height: 30px; - line-height: 30px; + line-height: 2; } .navbar-text { height: 50px; - line-height: 50px; + line-height: 3.3; } table.sortable thead { @@ -170,7 +171,7 @@ span.additional-metric-title { } .version { - line-height: 30px; + line-height: 2.5; vertical-align: bottom; font-size: 12px; padding: 0; From d05c9ee6e8441e54732e40de45d1d2311307908f Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Fri, 16 Jan 2015 12:56:40 -0800 Subject: [PATCH 161/215] [SPARK-4923][REPL] Add Developer API to REPL to allow re-publishing the REPL jar As requested in [SPARK-4923](https://issues.apache.org/jira/browse/SPARK-4923), I've provided a rough DeveloperApi for the repl. I've only done this for Scala 2.10 because it does not appear that Scala 2.11 is implemented. 
The Scala 2.11 repl still has the old `scala.tools.nsc` package and the SparkIMain does not appear to have the class server needed for shipping code over (unless this functionality has been moved elsewhere?). I also left alone the `ExecutorClassLoader` and `ConstructorCleaner` as I have no experience working with those classes. This marks the majority of methods in `SparkIMain` as _private_ with a few special cases being _private[repl]_ as other classes within the same package access them. Any public method has been marked with `DeveloperApi` as suggested by pwendell and I took the liberty of writing up a Scaladoc for each one to further elaborate their usage. As the Scala 2.11 REPL [conforms]((https://github.com/scala/scala/pull/2206)) to [JSR-223](http://docs.oracle.com/javase/8/docs/technotes/guides/scripting/), the [Spark Kernel](https://github.com/ibm-et/spark-kernel) uses the SparkIMain of Scala 2.10 in the same manner. So, I've taken care to expose methods predominately related to necessary functionality towards a JSR-223 scripting engine implementation. 1. The ability to _get_ variables from the interpreter (and other information like class/symbol/type) 2. The ability to _put_ variables into the interpreter 3. The ability to _compile_ code 4. The ability to _execute_ code 5. The ability to get contextual information regarding the scripting environment Additional functionality that I marked as exposed included the following: 1. The blocking initialization method (needed to actually start SparkIMain instance) 2. The class server uri (needed to set the _spark.repl.class.uri_ property after initialization), reduced from the entire class server 3. The class output directory (beneficial for tools like ours that need to inspect and use the directory where class files are served) 4. Suppression (quiet/silence) mechanics for output 5. Ability to add a jar to the compile/runtime classpath 6. The reset/close functionality 7. Metric information (last variable assignment, "needed" for extracting results from last execution, real variable name for better debugging) 8. Execution wrapper (useful to have, but debatable) Aside from `SparkIMain`, I updated other classes/traits and their methods in the _repl_ package to be private/package protected where possible. A few odd cases (like the SparkHelper being in the scala.tools.nsc package to expose a private variable) still exist, but I did my best at labelling them. `SparkCommandLine` has proven useful to extract settings and `SparkJLineCompletion` has proven to be useful in implementing auto-completion in the [Spark Kernel](https://github.com/ibm-et/spark-kernel) project. Other than those - and `SparkIMain` - my experience has yielded that other classes/methods are not necessary for interactive applications taking advantage of the REPL API. Tested via the following: $ export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" $ mvn -Phadoop-2.3 -DskipTests clean package && mvn -Phadoop-2.3 test Also did a quick verification that I could start the shell and execute some code: $ ./bin/spark-shell ... scala> val x = 3 x: Int = 3 scala> sc.parallelize(1 to 10).reduce(_+_) ... 
res1: Int = 55 Author: Chip Senkbeil Author: Chip Senkbeil Closes #4034 from rcsenkbeil/AddDeveloperApiToRepl and squashes the following commits: 053ca75 [Chip Senkbeil] Fixed failed build by adding missing DeveloperApi import c1b88aa [Chip Senkbeil] Added DeveloperApi to public classes in repl 6dc1ee2 [Chip Senkbeil] Added missing method to expose error reporting flag 26fd286 [Chip Senkbeil] Refactored other Scala 2.10 classes and methods to be private/package protected where possible 925c112 [Chip Senkbeil] Added DeveloperApi and Scaladocs to SparkIMain for Scala 2.10 --- .../apache/spark/repl/SparkCommandLine.scala | 9 +- .../apache/spark/repl/SparkExprTyper.scala | 2 +- .../org/apache/spark/repl/SparkHelper.scala | 17 + .../org/apache/spark/repl/SparkILoop.scala | 150 +++-- .../apache/spark/repl/SparkILoopInit.scala | 2 +- .../org/apache/spark/repl/SparkIMain.scala | 592 ++++++++++++++---- .../org/apache/spark/repl/SparkImports.scala | 2 +- .../spark/repl/SparkJLineCompletion.scala | 56 +- .../apache/spark/repl/SparkJLineReader.scala | 4 +- .../spark/repl/SparkMemberHandlers.scala | 2 +- .../spark/repl/SparkRunnerSettings.scala | 3 +- 11 files changed, 644 insertions(+), 195 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 05816941b54b..6480e2d24e04 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala @@ -19,14 +19,21 @@ package org.apache.spark.repl import scala.tools.nsc.{Settings, CompilerCommand} import scala.Predef._ +import org.apache.spark.annotation.DeveloperApi /** * Command class enabling Spark-specific command line options (provided by * org.apache.spark.repl.SparkRunnerSettings). 
+ * + * @example new SparkCommandLine(Nil).settings + * + * @param args The list of command line arguments + * @param settings The underlying settings to associate with this set of + * command-line options */ +@DeveloperApi class SparkCommandLine(args: List[String], override val settings: Settings) extends CompilerCommand(args, settings) { - def this(args: List[String], error: String => Unit) { this(args, new SparkRunnerSettings(error)) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala index f8432c8af6ed..5fb378112ef9 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -15,7 +15,7 @@ import scala.tools.nsc.ast.parser.Tokens.EOF import org.apache.spark.Logging -trait SparkExprTyper extends Logging { +private[repl] trait SparkExprTyper extends Logging { val repl: SparkIMain import repl._ diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala index 5340951d9133..955be17a73b8 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala @@ -17,6 +17,23 @@ package scala.tools.nsc +import org.apache.spark.annotation.DeveloperApi + +// NOTE: Forced to be public (and in scala.tools.nsc package) to access the +// settings "explicitParentLoader" method + +/** + * Provides exposure for the explicitParentLoader method on settings instances. + */ +@DeveloperApi object SparkHelper { + /** + * Retrieves the explicit parent loader for the provided settings. 
+ * + * @param settings The settings whose explicit parent loader to retrieve + * + * @return The Optional classloader representing the explicit parent loader + */ + @DeveloperApi def explicitParentLoader(settings: Settings) = settings.explicitParentLoader } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e56b74edba88..72c1a989999b 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -10,6 +10,8 @@ package org.apache.spark.repl import java.net.URL +import org.apache.spark.annotation.DeveloperApi + import scala.reflect.io.AbstractFile import scala.tools.nsc._ import scala.tools.nsc.backend.JavaPlatform @@ -57,20 +59,22 @@ import org.apache.spark.util.Utils * @author Lex Spoon * @version 1.2 */ -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, - val master: Option[String]) - extends AnyRef - with LoopCommands - with SparkILoopInit - with Logging -{ +@DeveloperApi +class SparkILoop( + private val in0: Option[BufferedReader], + protected val out: JPrintWriter, + val master: Option[String] +) extends AnyRef with LoopCommands with SparkILoopInit with Logging { def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) def this() = this(None, new JPrintWriter(Console.out, true), None) - var in: InteractiveReader = _ // the input stream from which commands come - var settings: Settings = _ - var intp: SparkIMain = _ + private var in: InteractiveReader = _ // the input stream from which commands come + + // NOTE: Exposed in package for testing + private[repl] var settings: Settings = _ + + private[repl] var intp: SparkIMain = _ @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i @@ -123,6 +127,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } + // NOTE: Must be public for visibility + @DeveloperApi var sparkContext: SparkContext = _ override def echoCommandMessage(msg: String) { @@ -130,45 +136,45 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } // def isAsync = !settings.Yreplsync.value - def isAsync = false + private[repl] def isAsync = false // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - def history = in.history + private def history = in.history /** The context class loader at the time this object was created */ protected val originalClassLoader = Utils.getContextOrSparkClassLoader // classpath entries added via :cp - var addedClasspath: String = "" + private var addedClasspath: String = "" /** A reverse list of commands to replay if the user requests a :replay */ - var replayCommandStack: List[String] = Nil + private var replayCommandStack: List[String] = Nil /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandStack.reverse + private def replayCommands = replayCommandStack.reverse /** Record a command for replay should the user request a :replay */ - def addReplay(cmd: String) = replayCommandStack ::= cmd + private def addReplay(cmd: String) = replayCommandStack ::= cmd - def savingReplayStack[T](body: => T): T = { + private def 
savingReplayStack[T](body: => T): T = { val saved = replayCommandStack try body finally replayCommandStack = saved } - def savingReader[T](body: => T): T = { + private def savingReader[T](body: => T): T = { val saved = in try body finally in = saved } - def sparkCleanUp(){ + private def sparkCleanUp(){ echo("Stopping spark context.") intp.beQuietDuring { command("sc.stop()") } } /** Close the interpreter and set the var to null. */ - def closeInterpreter() { + private def closeInterpreter() { if (intp ne null) { sparkCleanUp() intp.close() @@ -179,14 +185,16 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, class SparkILoopInterpreter extends SparkIMain(settings, out) { outer => - override lazy val formatting = new Formatting { + override private[repl] lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt } override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) } - /** Create a new interpreter. */ - def createInterpreter() { + /** + * Constructs a new interpreter. + */ + protected def createInterpreter() { require(settings != null) if (addedClasspath != "") settings.classpath.append(addedClasspath) @@ -207,7 +215,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** print a friendly help message */ - def helpCommand(line: String): Result = { + private def helpCommand(line: String): Result = { if (line == "") helpSummary() else uniqueCommand(line) match { case Some(lc) => echo("\n" + lc.longHelp) @@ -258,7 +266,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** Show the history */ - lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { override def usage = "[num]" def defaultLines = 20 @@ -279,21 +287,21 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // When you know you are most likely breaking into the middle // of a line being typed. This softens the blow. - protected def echoAndRefresh(msg: String) = { + private[repl] def echoAndRefresh(msg: String) = { echo("\n" + msg) in.redrawLine() } - protected def echo(msg: String) = { + private[repl] def echo(msg: String) = { out println msg out.flush() } - protected def echoNoNL(msg: String) = { + private def echoNoNL(msg: String) = { out print msg out.flush() } /** Search the history */ - def searchHistory(_cmdline: String) { + private def searchHistory(_cmdline: String) { val cmdline = _cmdline.toLowerCase val offset = history.index - history.size + 1 @@ -302,14 +310,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } private var currentPrompt = Properties.shellPromptString + + /** + * Sets the prompt string used by the REPL. + * + * @param prompt The new prompt string + */ + @DeveloperApi def setPrompt(prompt: String) = currentPrompt = prompt - /** Prompt to print when awaiting input */ + + /** + * Represents the current prompt string used by the REPL. 
+ * + * @return The current prompt string + */ + @DeveloperApi def prompt = currentPrompt import LoopCommand.{ cmd, nullary } /** Standard commands */ - lazy val standardCommands = List( + private lazy val standardCommands = List( cmd("cp", "", "add a jar or directory to the classpath", addClasspath), cmd("help", "[command]", "print this summary or command-specific help", helpCommand), historyCommand, @@ -333,7 +354,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, ) /** Power user commands */ - lazy val powerCommands: List[LoopCommand] = List( + private lazy val powerCommands: List[LoopCommand] = List( // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) ) @@ -459,7 +480,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - protected def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { + private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { override def tryClass(path: String): Array[Byte] = { val hd :: rest = path split '.' toList; // If there are dots in the name, the first segment is the @@ -581,7 +602,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // } // } - /** Available commands */ + /** + * Provides a list of available commands. + * + * @return The list of commands + */ + @DeveloperApi def commands: List[LoopCommand] = standardCommands /*++ ( if (isReplPower) powerCommands else Nil )*/ @@ -613,7 +639,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * command() for each line of input, and stops when * command() returns false. */ - def loop() { + private def loop() { def readOneLine() = { out.flush() in readLine prompt @@ -642,7 +668,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { + private def interpretAllFrom(file: File) { savingReader { savingReplayStack { file applyReader { reader => @@ -655,7 +681,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** create a new interpreter and replay the given commands */ - def replay() { + private def replay() { reset() if (replayCommandStack.isEmpty) echo("Nothing to replay.") @@ -665,7 +691,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, echo("") } } - def resetCommand() { + private def resetCommand() { echo("Resetting repl state.") if (replayCommandStack.nonEmpty) { echo("Forgetting this session history:\n") @@ -681,13 +707,13 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, reset() } - def reset() { + private def reset() { intp.reset() // unleashAndSetPhase() } /** fork a shell and run a command */ - lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { override def usage = "" def apply(line: String): Result = line match { case "" => showUsage() @@ -698,14 +724,14 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - def withFile(filename: String)(action: File => Unit) { + private def withFile(filename: String)(action: File => Unit) { val f = File(filename) if (f.exists) action(f) else echo("That file does not exist") } - def loadCommand(arg: String) 
= { + private def loadCommand(arg: String) = { var shouldReplay: Option[String] = None withFile(arg)(f => { interpretAllFrom(f) @@ -714,7 +740,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, Result(true, shouldReplay) } - def addAllClasspath(args: Seq[String]): Unit = { + private def addAllClasspath(args: Seq[String]): Unit = { var added = false var totalClasspath = "" for (arg <- args) { @@ -729,7 +755,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - def addClasspath(arg: String): Unit = { + private def addClasspath(arg: String): Unit = { val f = File(arg).normalize if (f.exists) { addedClasspath = ClassPath.join(addedClasspath, f.path) @@ -741,12 +767,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } - def powerCmd(): Result = { + private def powerCmd(): Result = { if (isReplPower) "Already in power mode." else enablePowerMode(false) } - def enablePowerMode(isDuringInit: Boolean) = { + private[repl] def enablePowerMode(isDuringInit: Boolean) = { // replProps.power setValue true // unleashAndSetPhase() // asyncEcho(isDuringInit, power.banner) @@ -759,12 +785,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // } // } - def asyncEcho(async: Boolean, msg: => String) { + private def asyncEcho(async: Boolean, msg: => String) { if (async) asyncMessage(msg) else echo(msg) } - def verbosity() = { + private def verbosity() = { // val old = intp.printResults // intp.printResults = !old // echo("Switched " + (if (old) "off" else "on") + " result printing.") @@ -773,7 +799,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. */ - def command(line: String): Result = { + private[repl] def command(line: String): Result = { if (line startsWith ":") { val cmd = line.tail takeWhile (x => !x.isWhitespace) uniqueCommand(cmd) match { @@ -789,7 +815,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) } - def pasteCommand(): Result = { + private def pasteCommand(): Result = { echo("// Entering paste mode (ctrl-D to finish)\n") val code = readWhile(_ => true) mkString "\n" echo("\n// Exiting paste mode, now interpreting.\n") @@ -820,7 +846,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * read, go ahead and interpret it. Return the full string * to be recorded for replay, if any. */ - def interpretStartingWith(code: String): Option[String] = { + private def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() @@ -874,7 +900,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } // runs :load `file` on any files passed via -i - def loadFiles(settings: Settings) = settings match { + private def loadFiles(settings: Settings) = settings match { case settings: SparkRunnerSettings => for (filename <- settings.loadfiles.value) { val cmd = ":load " + filename @@ -889,7 +915,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * unless settings or properties are such that it should start * with SimpleReader. 
*/ - def chooseReader(settings: Settings): InteractiveReader = { + private def chooseReader(settings: Settings): InteractiveReader = { if (settings.Xnojline.value || Properties.isEmacsShell) SimpleReader() else try new SparkJLineReader( @@ -903,8 +929,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val m = u.runtimeMirror(Utils.getSparkClassLoader) + private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe + private val m = u.runtimeMirror(Utils.getSparkClassLoader) private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = u.TypeTag[T]( m, @@ -913,7 +939,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] }) - def process(settings: Settings): Boolean = savingContextLoader { + private def process(settings: Settings): Boolean = savingContextLoader { if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") this.settings = settings @@ -972,6 +998,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, true } + // NOTE: Must be public for visibility + @DeveloperApi def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") val jars = SparkILoop.getAddedJars @@ -979,7 +1007,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) - .set("spark.repl.class.uri", intp.classServer.uri) + .set("spark.repl.class.uri", intp.classServerUri) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1014,7 +1042,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } @deprecated("Use `process` instead", "2.9.0") - def main(settings: Settings): Unit = process(settings) + private def main(settings: Settings): Unit = process(settings) } object SparkILoop { @@ -1033,7 +1061,7 @@ object SparkILoop { // Designed primarily for use by test code: take a String with a // bunch of code, and prints out a transcript of what it would look // like if you'd just typed it into the repl. - def runForTranscript(code: String, settings: Settings): String = { + private[repl] def runForTranscript(code: String, settings: Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => @@ -1071,7 +1099,7 @@ object SparkILoop { /** Creates an interpreter loop with default settings and feeds * the given code to it as input. 
*/ - def run(code: String, sets: Settings = new Settings): String = { + private[repl] def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => @@ -1087,5 +1115,5 @@ object SparkILoop { } } } - def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index da4286c5e487..99bd777c04fd 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -19,7 +19,7 @@ import org.apache.spark.SPARK_VERSION /** * Machinery for the asynchronous initialization of the repl. */ -trait SparkILoopInit { +private[repl] trait SparkILoopInit { self: SparkILoop => /** Print a welcome message */ diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index b646f0b6f086..35fb62564502 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -39,6 +39,7 @@ import scala.util.control.ControlThrowable import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils +import org.apache.spark.annotation.DeveloperApi // /** directory to save .class files to */ // private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { @@ -84,17 +85,18 @@ import org.apache.spark.util.Utils * @author Moez A. Abdel-Gawad * @author Lex Spoon */ + @DeveloperApi class SparkIMain( initialSettings: Settings, val out: JPrintWriter, propagateExceptions: Boolean = false) extends SparkImports with Logging { imain => - val conf = new SparkConf() + private val conf = new SparkConf() - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - lazy val outputDir = { + private lazy val outputDir = { val tmp = System.getProperty("java.io.tmpdir") val rootDir = conf.get("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) @@ -103,13 +105,20 @@ import org.apache.spark.util.Utils echo("Output directory: " + outputDir) } - val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles + /** + * Returns the path to the output directory containing all generated + * class files that will be served by the REPL class server. 
+ */ + @DeveloperApi + lazy val getClassOutputDirectory = outputDir + + private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServerPort = conf.getInt("spark.replClassServer.port", 0) - val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") + private val classServerPort = conf.getInt("spark.replClassServer.port", 0) + private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings - var printResults = true // whether to print result lines - var totalSilence = false // whether to print anything + private var printResults = true // whether to print result lines + private var totalSilence = false // whether to print anything private var _initializeComplete = false // compiler is initialized private var _isInitialized: Future[Boolean] = null // set up initialization future private var bindExceptions = true // whether to bind the lastException variable @@ -123,6 +132,14 @@ import org.apache.spark.util.Utils echo("Class server started, URI = " + classServer.uri) } + /** + * URI of the class server used to feed REPL compiled classes. + * + * @return The string representing the class server uri + */ + @DeveloperApi + def classServerUri = classServer.uri + /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -141,17 +158,18 @@ import org.apache.spark.util.Utils () => { counter += 1 ; counter } } - def compilerClasspath: Seq[URL] = ( + private def compilerClasspath: Seq[URL] = ( if (isInitializeComplete) global.classPath.asURLs else new PathResolver(settings).result.asURLs // the compiler's classpath ) - def settings = currentSettings - def mostRecentLine = prevRequestList match { + // NOTE: Exposed to repl package since accessed indirectly from SparkIMain + private[repl] def settings = currentSettings + private def mostRecentLine = prevRequestList match { case Nil => "" case req :: _ => req.originalLine } // Run the code body with the given boolean settings flipped to true. - def withoutWarnings[T](body: => T): T = beQuietDuring { + private def withoutWarnings[T](body: => T): T = beQuietDuring { val saved = settings.nowarn.value if (!saved) settings.nowarn.value = true @@ -164,16 +182,28 @@ import org.apache.spark.util.Utils def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) def this() = this(new Settings()) - lazy val repllog: Logger = new Logger { + private lazy val repllog: Logger = new Logger { val out: JPrintWriter = imain.out val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" } - lazy val formatting: Formatting = new Formatting { + private[repl] lazy val formatting: Formatting = new Formatting { val prompt = Properties.shellPromptString } - lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop + private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + /** + * Determines if errors were reported (typically during compilation). 
+ * + * @note This is not for runtime errors + * + * @return True if had errors, otherwise false + */ + @DeveloperApi + def isReportingErrors = reporter.hasErrors import formatting._ import reporter.{ printMessage, withoutTruncating } @@ -193,7 +223,8 @@ import org.apache.spark.util.Utils private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" // argument is a thunk to execute after init is done - def initialize(postInitSignal: => Unit) { + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def initialize(postInitSignal: => Unit) { synchronized { if (_isInitialized == null) { _isInitialized = io.spawn { @@ -203,15 +234,27 @@ import org.apache.spark.util.Utils } } } + + /** + * Initializes the underlying compiler/interpreter in a blocking fashion. + * + * @note Must be executed before using SparkIMain! + */ + @DeveloperApi def initializeSynchronous(): Unit = { if (!isInitializeComplete) { _initialize() assert(global != null, global) } } - def isInitializeComplete = _initializeComplete + private def isInitializeComplete = _initializeComplete /** the public, go through the future compiler */ + + /** + * The underlying compiler used to generate ASTs and execute code. + */ + @DeveloperApi lazy val global: Global = { if (isInitializeComplete) _compiler else { @@ -226,13 +269,13 @@ import org.apache.spark.util.Utils } } @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - lazy val compiler: global.type = global + private lazy val compiler: global.type = global import global._ import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} - implicit class ReplTypeOps(tp: Type) { + private implicit class ReplTypeOps(tp: Type) { def orElse(other: => Type): Type = if (tp ne NoType) tp else other def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) } @@ -240,7 +283,8 @@ import org.apache.spark.util.Utils // TODO: If we try to make naming a lazy val, we run into big time // scalac unhappiness with what look like cycles. It has not been easy to // reduce, but name resolution clearly takes different paths. - object naming extends { + // NOTE: Exposed to repl package since used by SparkExprTyper + private[repl] object naming extends { val global: imain.global.type = imain.global } with Naming { // make sure we don't overwrite their unwisely named res3 etc. @@ -254,22 +298,43 @@ import org.apache.spark.util.Utils } import naming._ - object deconstruct extends { + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] object deconstruct extends { val global: imain.global.type = imain.global } with StructuredTypeStrings - lazy val memberHandlers = new { + // NOTE: Exposed to repl package since used by SparkImports + private[repl] lazy val memberHandlers = new { val intp: imain.type = imain } with SparkMemberHandlers import memberHandlers._ - /** Temporarily be quiet */ + /** + * Suppresses overwriting print results during the operation. + * + * @param body The block to execute + * @tparam T The return type of the block + * + * @return The result from executing the block + */ + @DeveloperApi def beQuietDuring[T](body: => T): T = { val saved = printResults printResults = false try body finally printResults = saved } + + /** + * Completely masks all output during the operation (minus JVM standard + * out and error). 
+ * + * @param operation The block to execute + * @tparam T The return type of the block + * + * @return The result from executing the block + */ + @DeveloperApi def beSilentDuring[T](operation: => T): T = { val saved = totalSilence totalSilence = true @@ -277,10 +342,10 @@ import org.apache.spark.util.Utils finally totalSilence = saved } - def quietRun[T](code: String) = beQuietDuring(interpret(code)) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { + private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { case t: ControlThrowable => throw t case t: Throwable => logDebug(label + ": " + unwrap(t)) @@ -298,14 +363,44 @@ import org.apache.spark.util.Utils finally bindExceptions = true } + /** + * Contains the code (in string form) representing a wrapper around all + * code executed by this instance. + * + * @return The wrapper code as a string + */ + @DeveloperApi def executionWrapper = _executionWrapper + + /** + * Sets the code to use as a wrapper around all code executed by this + * instance. + * + * @param code The wrapper code as a string + */ + @DeveloperApi def setExecutionWrapper(code: String) = _executionWrapper = code + + /** + * Clears the code used as a wrapper around all code executed by + * this instance. + */ + @DeveloperApi def clearExecutionWrapper() = _executionWrapper = "" /** interpreter settings */ - lazy val isettings = new SparkISettings(this) + private lazy val isettings = new SparkISettings(this) - /** Instantiate a compiler. Overridable. */ + /** + * Instantiates a new compiler used by SparkIMain. Overridable to provide + * own instance of a compiler. 
+ * + * @param settings The settings to provide the compiler + * @param reporter The reporter to use for compiler output + * + * @return The compiler as a Global + */ + @DeveloperApi protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { settings.outputDirs setSingleOutput virtualDirectory settings.exposeEmptyPackage.value = true @@ -320,13 +415,14 @@ import org.apache.spark.util.Utils * @note Currently only supports jars, not directories * @param urls The list of items to add to the compile and runtime classpaths */ + @DeveloperApi def addUrlsToClassPath(urls: URL*): Unit = { new Run // Needed to force initialization of "something" to correctly load Scala classes from jars urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling } - protected def updateCompilerClassPath(urls: URL*): Unit = { + private def updateCompilerClassPath(urls: URL*): Unit = { require(!global.forMSIL) // Only support JavaPlatform val platform = global.platform.asInstanceOf[JavaPlatform] @@ -342,7 +438,7 @@ import org.apache.spark.util.Utils global.invalidateClassPathEntries(urls.map(_.getPath): _*) } - protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { + private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { // Collect our new jars/directories and add them to the existing set of classpaths val allClassPaths = ( platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ @@ -365,7 +461,13 @@ import org.apache.spark.util.Utils new MergedClassPath(allClassPaths, platform.classPath.context) } - /** Parent classloader. Overridable. */ + /** + * Represents the parent classloader used by this instance. Can be + * overridden to provide alternative classloader. + * + * @return The classloader used as the parent loader of this instance + */ + @DeveloperApi protected def parentClassLoader: ClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) @@ -382,16 +484,18 @@ import org.apache.spark.util.Utils shadow the old ones, and old code objects refer to the old definitions. */ - def resetClassLoader() = { + private def resetClassLoader() = { logDebug("Setting new classloader: was " + _classLoader) _classLoader = null ensureClassLoader() } - final def ensureClassLoader() { + private final def ensureClassLoader() { if (_classLoader == null) _classLoader = makeClassLoader() } - def classLoader: AbstractFileClassLoader = { + + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def classLoader: AbstractFileClassLoader = { ensureClassLoader() _classLoader } @@ -418,27 +522,58 @@ import org.apache.spark.util.Utils _runtimeClassLoader }) - def getInterpreterClassLoader() = classLoader + private def getInterpreterClassLoader() = classLoader // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() + // NOTE: Exposed to repl package since used by SparkILoopInit + private[repl] def setContextClassLoader() = classLoader.setAsContext() - /** Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * {{{ - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - * }}} + /** + * Returns the real name of a class based on its repl-defined name. 
+ * + * ==Example== + * Given a simple repl-defined name, returns the real name of + * the class representing it, e.g. for "Bippy" it may return + * {{{ + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + * }}} + * + * @param simpleName The repl-defined name whose real name to retrieve + * + * @return Some real name if the simple name exists, else None */ + @DeveloperApi def generatedName(simpleName: String): Option[String] = { if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) else optFlatName(simpleName) } - def flatName(id: String) = optFlatName(id) getOrElse id - def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def flatName(id: String) = optFlatName(id) getOrElse id + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + + /** + * Retrieves all simple names contained in the current instance. + * + * @return A list of sorted names + */ + @DeveloperApi def allDefinedNames = definedNameMap.keys.toList.sorted - def pathToType(id: String): String = pathToName(newTypeName(id)) - def pathToTerm(id: String): String = pathToName(newTermName(id)) + + private def pathToType(id: String): String = pathToName(newTypeName(id)) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id)) + + /** + * Retrieves the full code path to access the specified simple name + * content. + * + * @param name The simple name of the target whose path to determine + * + * @return The full path used to access the specified target (name) + */ + @DeveloperApi def pathToName(name: Name): String = { if (definedNameMap contains name) definedNameMap(name) fullPath name @@ -457,13 +592,13 @@ import org.apache.spark.util.Utils } /** Stubs for work in progress. 
*/ - def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) } } - def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + private def handleTermRedefinition(name: TermName, old: Request, req: Request) = { for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { // Printing the types here has a tendency to cause assertion errors, like // assertion failed: fatal: has owner value x, but a class owner is required @@ -473,7 +608,7 @@ import org.apache.spark.util.Utils } } - def recordRequest(req: Request) { + private def recordRequest(req: Request) { if (req == null || referencedNameMap == null) return @@ -504,12 +639,12 @@ import org.apache.spark.util.Utils } } - def replwarn(msg: => String) { + private def replwarn(msg: => String) { if (!settings.nowarnings.value) printMessage(msg) } - def isParseable(line: String): Boolean = { + private def isParseable(line: String): Boolean = { beSilentDuring { try parse(line) match { case Some(xs) => xs.nonEmpty // parses as-is @@ -522,22 +657,32 @@ import org.apache.spark.util.Utils } } - def compileSourcesKeepingRun(sources: SourceFile*) = { + private def compileSourcesKeepingRun(sources: SourceFile*) = { val run = new Run() reporter.reset() run compileSources sources.toList (!reporter.hasErrors, run) } - /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false otherwise. + /** + * Compiles specified source files. + * + * @param sources The sequence of source files to compile + * + * @return True if successful, otherwise false */ + @DeveloperApi def compileSources(sources: SourceFile*): Boolean = compileSourcesKeepingRun(sources: _*)._1 - /** Compile a string. Returns true if there are no - * compilation errors, or false otherwise. + /** + * Compiles a string of code. + * + * @param code The string of code to compile + * + * @return True if successful, otherwise false */ + @DeveloperApi def compileString(code: String): Boolean = compileSources(new BatchSourceFile("
    <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
    <tr>
      <td><code>spark.yarn.am.memory</code></td>
      <td>512m</td>
      <td>
+       Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g).
+       In cluster mode, use spark.driver.memory instead.
+     </td>
    </tr>
    <tr>
      <td><code>spark.yarn.am.waitTime</code></td>
      <td>100000</td>
    <tr>
      <td><code>spark.yarn.driver.memoryOverhead</code></td>
      <td>driverMemory * 0.07, with minimum of 384</td>
      <td>
-       The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
+       The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
+     </td>
    </tr>
    <tr>
      <td><code>spark.yarn.am.memoryOverhead</code></td>
      <td>AM memory * 0.07, with minimum of 384</td>
      <td>
+       Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode.
      </td>
    </tr>
    <tr>
      <td><code>spark.yarn.am.extraJavaOptions</code></td>
      <td>(none)</td>
      <td>
-       A string of extra JVM options to pass to the Yarn ApplicationMaster in client mode.
+       A string of extra JVM options to pass to the YARN Application Master in client mode. In cluster mode, use spark.driver.extraJavaOptions instead.
      </td>
    </tr>
    spark.storage.memoryMapThreshold81922097152 Size of a block, in bytes, above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory From 6942b974adad396cba2799eac1fa90448cea4da7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 11 Jan 2015 16:23:28 -0800 Subject: [PATCH 120/215] [SPARK-4951][Core] Fix the issue that a busy executor may be killed A few changes to fix this issue: 1. Handle the case that receiving `SparkListenerTaskStart` before `SparkListenerBlockManagerAdded`. 2. Don't add `executorId` to `removeTimes` when the executor is busy. 3. Use `HashMap.retain` to safely traverse the HashMap and remove items. 4. Use the same lock in ExecutorAllocationManager and ExecutorAllocationListener to fix the race condition in `totalPendingTasks`. 5. Move the blocking codes out of the message processing code in YarnSchedulerActor. Author: zsxwing Closes #3783 from zsxwing/SPARK-4951 and squashes the following commits: d51fa0d [zsxwing] Add comments 2e365ce [zsxwing] Remove expired executors from 'removeTimes' and add idle executors back when a new executor joins 49f61a9 [zsxwing] Eliminate duplicate executor registered warnings d4c4e9a [zsxwing] Minor fixes for the code style 05f6238 [zsxwing] Move the blocking codes out of the message processing code 105ba3a [zsxwing] Fix the race condition in totalPendingTasks d5c615d [zsxwing] Fix the issue that a busy executor may be killed --- .../spark/ExecutorAllocationManager.scala | 117 ++++++++++++------ .../cluster/YarnSchedulerBackend.scala | 23 +++- .../ExecutorAllocationManagerSuite.scala | 49 +++++++- 3 files changed, 144 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index e9e90e3f2f65..a0ee2a7cbb2a 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -65,6 +65,9 @@ private[spark] class ExecutorAllocationManager( listenerBus: LiveListenerBus, conf: SparkConf) extends Logging { + + allocationManager => + import ExecutorAllocationManager._ // Lower and upper bounds on the number of executors. These are required. @@ -121,7 +124,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new RealClock // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener(this) + private val listener = new ExecutorAllocationListener /** * Verify that the settings specified through the config are valid. 
@@ -209,11 +212,12 @@ private[spark] class ExecutorAllocationManager( addTime += sustainedSchedulerBacklogTimeout * 1000 } - removeTimes.foreach { case (executorId, expireTime) => - if (now >= expireTime) { + removeTimes.retain { case (executorId, expireTime) => + val expired = now >= expireTime + if (expired) { removeExecutor(executorId) - removeTimes.remove(executorId) } + !expired } } @@ -291,7 +295,7 @@ private[spark] class ExecutorAllocationManager( // Do not kill the executor if we have already reached the lower bound val numExistingExecutors = executorIds.size - executorsPendingToRemove.size if (numExistingExecutors - 1 < minNumExecutors) { - logInfo(s"Not removing idle executor $executorId because there are only " + + logDebug(s"Not removing idle executor $executorId because there are only " + s"$numExistingExecutors executor(s) left (limit $minNumExecutors)") return false } @@ -315,7 +319,11 @@ private[spark] class ExecutorAllocationManager( private def onExecutorAdded(executorId: String): Unit = synchronized { if (!executorIds.contains(executorId)) { executorIds.add(executorId) - executorIds.foreach(onExecutorIdle) + // If an executor (call this executor X) is not removed because the lower bound + // has been reached, it will no longer be marked as idle. When new executors join, + // however, we are no longer at the lower bound, and so we must mark executor X + // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951) + executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle) logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") if (numExecutorsPending > 0) { numExecutorsPending -= 1 @@ -373,10 +381,14 @@ private[spark] class ExecutorAllocationManager( * the executor is not already marked as idle. */ private def onExecutorIdle(executorId: String): Unit = synchronized { - if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { - logDebug(s"Starting idle timer for $executorId because there are no more tasks " + - s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") - removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + if (executorIds.contains(executorId)) { + if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + logDebug(s"Starting idle timer for $executorId because there are no more tasks " + + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + } + } else { + logWarning(s"Attempted to mark unknown executor $executorId idle") } } @@ -396,25 +408,24 @@ private[spark] class ExecutorAllocationManager( * and consistency of events returned by the listener. For simplicity, it does not account * for speculated tasks. 
*/ - private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager) - extends SparkListener { + private class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { - synchronized { - val stageId = stageSubmitted.stageInfo.stageId - val numTasks = stageSubmitted.stageInfo.numTasks + val stageId = stageSubmitted.stageInfo.stageId + val numTasks = stageSubmitted.stageInfo.numTasks + allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() } } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { - synchronized { - val stageId = stageCompleted.stageInfo.stageId + val stageId = stageCompleted.stageInfo.stageId + allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId @@ -426,39 +437,49 @@ private[spark] class ExecutorAllocationManager( } } - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { val stageId = taskStart.stageId val taskId = taskStart.taskInfo.taskId val taskIndex = taskStart.taskInfo.index val executorId = taskStart.taskInfo.executorId - // If this is the last pending task, mark the scheduler queue as empty - stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() + allocationManager.synchronized { + // This guards against the race condition in which the `SparkListenerTaskStart` + // event is posted before the `SparkListenerBlockManagerAdded` event, which is + // possible because these events are posted in different threads. 
(see SPARK-4951) + if (!allocationManager.executorIds.contains(executorId)) { + allocationManager.onExecutorAdded(executorId) + } + + // If this is the last pending task, mark the scheduler queue as empty + stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + val numTasksScheduled = stageIdToTaskIndices(stageId).size + val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) + if (numTasksScheduled == numTasksTotal) { + // No more pending tasks for this stage + stageIdToNumTasks -= stageId + if (stageIdToNumTasks.isEmpty) { + allocationManager.onSchedulerQueueEmpty() + } } - } - // Mark the executor on which this task is scheduled as busy - executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId - allocationManager.onExecutorBusy(executorId) + // Mark the executor on which this task is scheduled as busy + executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId + allocationManager.onExecutorBusy(executorId) + } } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId - - // If the executor is no longer running scheduled any tasks, mark it as idle - if (executorIdToTaskIds.contains(executorId)) { - executorIdToTaskIds(executorId) -= taskId - if (executorIdToTaskIds(executorId).isEmpty) { - executorIdToTaskIds -= executorId - allocationManager.onExecutorIdle(executorId) + allocationManager.synchronized { + // If the executor is no longer running scheduled any tasks, mark it as idle + if (executorIdToTaskIds.contains(executorId)) { + executorIdToTaskIds(executorId) -= taskId + if (executorIdToTaskIds(executorId).isEmpty) { + executorIdToTaskIds -= executorId + allocationManager.onExecutorIdle(executorId) + } } } } @@ -466,7 +487,12 @@ private[spark] class ExecutorAllocationManager( override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { val executorId = blockManagerAdded.blockManagerId.executorId if (executorId != SparkContext.DRIVER_IDENTIFIER) { - allocationManager.onExecutorAdded(executorId) + // This guards against the race condition in which the `SparkListenerTaskStart` + // event is posted before the `SparkListenerBlockManagerAdded` event, which is + // possible because these events are posted in different threads. (see SPARK-4951) + if (!allocationManager.executorIds.contains(executorId)) { + allocationManager.onExecutorAdded(executorId) + } } } @@ -478,12 +504,23 @@ private[spark] class ExecutorAllocationManager( /** * An estimate of the total number of pending tasks remaining for currently running stages. Does * not account for tasks which may have failed and been resubmitted. + * + * Note: This is not thread-safe without the caller owning the `allocationManager` lock. */ def totalPendingTasks(): Int = { stageIdToNumTasks.map { case (stageId, numTasks) => numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) }.sum } + + /** + * Return true if an executor is not currently running a task, and false otherwise. + * + * Note: This is not thread-safe without the caller owning the `allocationManager` lock. 
+ */ + def isExecutorIdle(executorId: String): Boolean = { + !executorIdToTaskIds.contains(executorId) + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 50721b9d6cd6..f14aaeea0a25 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster +import scala.concurrent.{Future, ExecutionContext} + import akka.actor.{Actor, ActorRef, Props} import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} @@ -24,7 +26,9 @@ import org.apache.spark.SparkContext import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.{AkkaUtils, Utils} + +import scala.util.control.NonFatal /** * Abstract Yarn scheduler backend that contains common logic @@ -97,6 +101,9 @@ private[spark] abstract class YarnSchedulerBackend( private class YarnSchedulerActor extends Actor { private var amActor: Option[ActorRef] = None + implicit val askAmActorExecutor = ExecutionContext.fromExecutor( + Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) + override def preStart(): Unit = { // Listen for disassociation events context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -110,7 +117,12 @@ private[spark] abstract class YarnSchedulerBackend( case r: RequestExecutors => amActor match { case Some(actor) => - sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + val driverActor = sender + Future { + driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + } onFailure { + case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) + } case None => logWarning("Attempted to request executors before the AM has registered!") sender ! false @@ -119,7 +131,12 @@ private[spark] abstract class YarnSchedulerBackend( case k: KillExecutors => amActor match { case Some(actor) => - sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + val driverActor = sender + Future { + driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + } onFailure { + case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) + } case None => logWarning("Attempted to kill executors before the AM has registered!") sender ! 
false diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index c817f6dcede7..0e4df17c1bf8 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.collection.mutable + import org.scalatest.{FunSuite, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -143,11 +145,17 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { // Verify that running a task reduces the cap sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(numExecutorsPending(manager) === 4) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsPending(manager) === 5) assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 7) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 0) assert(numExecutorsPending(manager) === 7) assert(numExecutorsToAdd(manager) === 1) @@ -325,6 +333,8 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val manager = sc.executorAllocationManager.get manager.setClock(clock) + executorIds(manager).asInstanceOf[mutable.Set[String]] ++= List("1", "2", "3") + // Starting remove timer is idempotent for each executor assert(removeTimes(manager).isEmpty) onExecutorIdle(manager, "1") @@ -597,6 +607,41 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(removeTimes(manager).size === 1) } + test("SPARK-4951: call onTaskStart before onBlockManagerAdded") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(executorIds(manager).isEmpty) + assert(removeTimes(manager).isEmpty) + + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + assert(executorIds(manager).size === 1) + assert(executorIds(manager).contains("executor-1")) + assert(removeTimes(manager).size === 0) + } + + test("SPARK-4951: onExecutorAdded should not add a busy executor to removeTimes") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(executorIds(manager).isEmpty) + assert(removeTimes(manager).isEmpty) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + + assert(executorIds(manager).size === 1) + assert(executorIds(manager).contains("executor-1")) + assert(removeTimes(manager).size === 0) + + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-2", "host1", 1), 100L)) + assert(executorIds(manager).size === 2) + assert(executorIds(manager).contains("executor-2")) + assert(removeTimes(manager).size === 1) + assert(removeTimes(manager).contains("executor-2")) + 
assert(!removeTimes(manager).contains("executor-1")) + } } /** From f38ef6586c2980183c983b2aa14a5ddc1856b7b7 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Sun, 11 Jan 2015 16:32:47 -0800 Subject: [PATCH 121/215] [SPARK-4033][Examples]Input of the SparkPi too big causes the emption exception If input of the SparkPi args is larger than the 25000, the integer 'n' inside the code will be overflow, and may be a negative number. And it causes the (0 until n) Seq as an empty seq, then doing the action 'reduce' will throw the UnsupportedOperationException("empty collection"). The max size of the input of sc.parallelize is Int.MaxValue - 1, not the Int.MaxValue. Author: huangzhaowei Closes #2874 from SaintBacchus/SparkPi and squashes the following commits: 62d7cd7 [huangzhaowei] Add a commit to explain the modify 4cdc388 [huangzhaowei] Update SparkPi.scala 9a2fb7b [huangzhaowei] Input of the SparkPi is too big --- .../src/main/scala/org/apache/spark/examples/SparkPi.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 9fbb0a800d73..35b8dd6c29b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -27,8 +27,8 @@ object SparkPi { val conf = new SparkConf().setAppName("Spark Pi") val spark = new SparkContext(conf) val slices = if (args.length > 0) args(0).toInt else 2 - val n = 100000 * slices - val count = spark.parallelize(1 to n, slices).map { i => + val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow + val count = spark.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 if (x*x + y*y < 1) 1 else 0 From 2130de9d8f50f52b9b2d296b377df81d840546b3 Mon Sep 17 00:00:00 2001 From: Travis Galoppo Date: Sun, 11 Jan 2015 21:31:16 -0800 Subject: [PATCH 122/215] SPARK-5018 [MLlib] [WIP] Make MultivariateGaussian public Moving MutlivariateGaussian from private[mllib] to public. The class uses Breeze vectors internally, so this involves creating a public interface using MLlib vectors and matrices. This initial commit provides public construction, accessors for mean/covariance, density and log-density. Other potential methods include entropy and sample generation. Author: Travis Galoppo Closes #3923 from tgaloppo/spark-5018 and squashes the following commits: 2b15587 [Travis Galoppo] Style correction b4121b4 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' into spark-5018 e30a100 [Travis Galoppo] Made mu, sigma private[mllib] members of MultivariateGaussian Moved MultivariateGaussian (and test suite) from stat.impl to stat.distribution (required updates in GaussianMixture{EM,Model}.scala) Marked MultivariateGaussian as @DeveloperApi Fixed style error 9fa3bb7 [Travis Galoppo] Style improvements 91a5fae [Travis Galoppo] Rearranged equation for part of density function 8c35381 [Travis Galoppo] Fixed accessor methods to match member variable names. 
Modified calculations to avoid log(pow(x,y)) calculations 0943dc4 [Travis Galoppo] SPARK-5018 4dee9e1 [Travis Galoppo] SPARK-5018 --- .../mllib/clustering/GaussianMixtureEM.scala | 11 +-- .../clustering/GaussianMixtureModel.scala | 2 +- .../MultivariateGaussian.scala | 70 ++++++++++++++----- .../MultivariateGaussianSuite.scala | 33 +++++---- 4 files changed, 75 insertions(+), 41 deletions(-) rename mllib/src/main/scala/org/apache/spark/mllib/stat/{impl => distribution}/MultivariateGaussian.scala (61%) rename mllib/src/test/scala/org/apache/spark/mllib/stat/{impl => distribution}/MultivariateGaussianSuite.scala (72%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index b3c5631cc4cc..d8e134619411 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -20,10 +20,11 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.IndexedSeq import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose} -import org.apache.spark.rdd.RDD + import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} -import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils /** @@ -134,7 +135,7 @@ class GaussianMixtureEM private ( // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) => - new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix) + new MultivariateGaussian(mu, sigma) }) case None => { @@ -176,8 +177,8 @@ class GaussianMixtureEM private ( } // Need to convert the breeze matrices to MLlib matrices - val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) } - val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) } + val means = Array.tabulate(k) { i => gaussians(i).mu } + val sigmas = Array.tabulate(k) { i => gaussians(i).sigma } new GaussianMixtureModel(weights, means, sigmas) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index b461ea4f0f06..416cad080c40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -21,7 +21,7 @@ import breeze.linalg.{DenseVector => BreezeVector} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, Vector} -import org.apache.spark.mllib.stat.impl.MultivariateGaussian +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala similarity index 61% rename from mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala rename to mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index bc7f6c5197ac..fd186b5ee6f7 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -15,13 +15,16 @@ * limitations under the License. */ -package org.apache.spark.mllib.stat.impl +package org.apache.spark.mllib.stat.distribution import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym} +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.mllib.util.MLUtils /** + * :: DeveloperApi :: * This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. @@ -30,33 +33,64 @@ import org.apache.spark.mllib.util.MLUtils * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ -private[mllib] class MultivariateGaussian( - val mu: DBV[Double], - val sigma: DBM[Double]) extends Serializable { +@DeveloperApi +class MultivariateGaussian ( + val mu: Vector, + val sigma: Matrix) extends Serializable { + require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") + require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") + + private val breezeMu = mu.toBreeze.toDenseVector + + /** + * private[mllib] constructor + * + * @param mu The mean vector of the distribution + * @param sigma The covariance matrix of the distribution + */ + private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { + this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) + } + /** * Compute distribution dependent constants: - * rootSigmaInv = D^(-1/2) * U, where sigma = U * D * U.t - * u = (2*pi)^(-k/2) * det(sigma)^(-1/2) + * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** Returns density of this multivariate Gaussian at given point, x */ - def pdf(x: DBV[Double]): Double = { - val delta = x - mu + def pdf(x: Vector): Double = { + pdf(x.toBreeze.toDenseVector) + } + + /** Returns the log-density of this multivariate Gaussian at given point, x */ + def logpdf(x: Vector): Double = { + logpdf(x.toBreeze.toDenseVector) + } + + /** Returns density of this multivariate Gaussian at given point, x */ + private[mllib] def pdf(x: DBV[Double]): Double = { + math.exp(logpdf(x)) + } + + /** Returns the log-density of this multivariate Gaussian at given point, x */ + private[mllib] def logpdf(x: DBV[Double]): Double = { + val delta = x - breezeMu val v = rootSigmaInv * delta - u * math.exp(v.t * v * -0.5) + u + v.t * v * -0.5 } /** * Calculate distribution dependent components used for the density function: - * pdf(x) = (2*pi)^(-k/2) * det(sigma)^(-1/2) * exp( (-1/2) * (x-mu).t * inv(sigma) * (x-mu) ) + * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. * * We here compute distribution-fixed parts - * (2*pi)^(-k/2) * det(sigma)^(-1/2) + * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and - * D^(-1/2) * U, where sigma = U * D * U.t + * D^(-1/2)^ * U, where sigma = U * D * U.t * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. 
Noting that covariance matrices are always symmetric and positive semi-definite, @@ -65,11 +99,11 @@ private[mllib] class MultivariateGaussian( * * sigma = U * D * U.t * inv(Sigma) = U * inv(D) * U.t - * = (D^{-1/2} * U).t * (D^{-1/2} * U) + * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) * * and thus * - * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2} * U * (x-mu))^2 + * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ * * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered @@ -77,21 +111,21 @@ private[mllib] class MultivariateGaussian( * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { - val eigSym.EigSym(d, u) = eigSym(sigma) // sigma = u * diag(d) * u.t + val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length try { - // pseudo-determinant is product of all non-zero singular values - val pdetSigma = d.activeValuesIterator.filter(_ > tol).reduce(_ * _) + // log(pseudo-determinant) is sum of the logs of all non-zero singular values + val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - (pinvS * u, math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(pdetSigma, -0.5)) + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => throw new IllegalArgumentException("Covariance matrix has no non-zero singular values") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala similarity index 72% rename from mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index d58f2587e55a..fac2498e4dcb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -15,54 +15,53 @@ * limitations under the License. 
*/ -package org.apache.spark.mllib.stat.impl +package org.apache.spark.mllib.stat.distribution import org.scalatest.FunSuite -import breeze.linalg.{ DenseVector => BDV, DenseMatrix => BDM } - +import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { test("univariate") { - val x1 = new BDV(Array(0.0)) - val x2 = new BDV(Array(1.5)) + val x1 = Vectors.dense(0.0) + val x2 = Vectors.dense(1.5) - val mu = new BDV(Array(0.0)) - val sigma1 = new BDM(1, 1, Array(1.0)) + val mu = Vectors.dense(0.0) + val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - val sigma2 = new BDM(1, 1, Array(4.0)) + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } test("multivariate") { - val x1 = new BDV(Array(0.0, 0.0)) - val x2 = new BDV(Array(1.0, 1.0)) + val x1 = Vectors.dense(0.0, 0.0) + val x2 = Vectors.dense(1.0, 1.0) - val mu = new BDV(Array(0.0, 0.0)) - val sigma1 = new BDM(2, 2, Array(1.0, 0.0, 0.0, 1.0)) + val mu = Vectors.dense(0.0, 0.0) + val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - val sigma2 = new BDM(2, 2, Array(4.0, -1.0, -1.0, 2.0)) + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } test("multivariate degenerate") { - val x1 = new BDV(Array(0.0, 0.0)) - val x2 = new BDV(Array(1.0, 1.0)) + val x1 = Vectors.dense(0.0, 0.0) + val x2 = Vectors.dense(1.0, 1.0) - val mu = new BDV(Array(0.0, 0.0)) - val sigma = new BDM(2, 2, Array(1.0, 1.0, 1.0, 1.0)) + val mu = Vectors.dense(0.0, 0.0) + val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) assert(dist.pdf(x1) ~== 0.11254 absTol 1E-5) assert(dist.pdf(x2) ~== 0.068259 absTol 1E-5) From 82fd38dcdcc9f7df18930c0e08cc8ec34eaee828 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 12 Jan 2015 10:47:12 -0800 Subject: [PATCH 123/215] [SPARK-5200] Disable web UI in Hive ThriftServer tests Disables the Spark web UI in HiveThriftServer2Suite in order to prevent Jenkins test failures due to port contention. Author: Josh Rosen Closes #3998 from JoshRosen/SPARK-5200 and squashes the following commits: a384416 [Josh Rosen] [SPARK-5200] Disable web UI in Hive Thriftserver tests. 
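The flag is an ordinary Spark configuration property, so the same effect can be obtained by a test that constructs its own context directly rather than going through the start script. A minimal sketch (the helper name is illustrative and not part of this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical helper: build a local context with the web UI disabled so that
// concurrent test runs do not contend for the UI port. The suite below achieves
// the same thing by adding `--conf spark.ui.enabled=false` to the command that
// launches HiveThriftServer2.
def localContextWithoutUI(appName: String): SparkContext = {
  val conf = new SparkConf()
    .setMaster("local[2]")
    .setAppName(appName)
    .set("spark.ui.enabled", "false")
  new SparkContext(conf)
}
```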
--- .../spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 7814aa38f414..b52a51d11e4a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -143,6 +143,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } else { s"""$startScript @@ -153,6 +154,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } From ef9224e08010420b570c21a0b9208d22792a24fe Mon Sep 17 00:00:00 2001 From: lianhuiwang Date: Mon, 12 Jan 2015 10:57:12 -0800 Subject: [PATCH 124/215] [SPARK-5102][Core]subclass of MapStatus needs to be registered with Kryo CompressedMapStatus and HighlyCompressedMapStatus needs to be registered with Kryo, because they are subclass of MapStatus. Author: lianhuiwang Closes #4007 from lianhuiwang/SPARK-5102 and squashes the following commits: 9d2238a [lianhuiwang] remove register of MapStatus 05a285d [lianhuiwang] subclass of MapStatus needs to be registered with Kryo --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d2947dcea4f7..d56e23ce4478 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer @@ -207,7 +207,8 @@ private[serializer] object KryoSerializer { classOf[PutBlock], classOf[GotBlock], classOf[GetBlock], - classOf[MapStatus], + classOf[CompressedMapStatus], + classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], From 13e610b88eb83b11785737ab1d99927f676f81c6 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 12 Jan 2015 11:00:56 -0800 Subject: [PATCH 125/215] SPARK-4159 [BUILD] Addendum: improve running of single test after enabling Java tests https://issues.apache.org/jira/browse/SPARK-4159 was resolved but as Sandy points out, the guidance in https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools under "Running Individual Tests" no longer quite works, not optimally. 
This minor change is not really the important change, which is an update to the wiki text. The correct way to run one Scala test suite in Maven is now: ``` mvn test -DwildcardSuites=org.apache.spark.io.CompressionCodecSuite -Dtests=none ``` The correct way to run one Java test is ``` mvn test -DwildcardSuites=none -Dtests=org.apache.spark.streaming.JavaAPISuite ``` Basically, you have to set two properties in order to suppress all of one type of test (with a non-existent test name like 'none') and all but one test of the other type. The change in the PR just prevents Surefire from barfing when it finds no "none" test. Author: Sean Owen Closes #3993 from srowen/SPARK-4159 and squashes the following commits: 83106d7 [Sean Owen] Default failIfNoTests to false to enable the -DwildcardSuites=... -Dtests=... syntax for running one test to work --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index aadcdfd1083c..0403c008be4f 100644 --- a/pom.xml +++ b/pom.xml @@ -1130,6 +1130,7 @@ ${test_classpath} true + false From a3978f3e156e0ca67e978f1795b238ddd69ff9a6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 12 Jan 2015 11:57:59 -0800 Subject: [PATCH 126/215] [SPARK-5078] Optionally read from SPARK_LOCAL_HOSTNAME Current spark lets you set the ip address using SPARK_LOCAL_IP, but then this is given to akka after doing a reverse DNS lookup. This makes it difficult to run spark in Docker. You can already change the hostname that is used programmatically, but it would be nice to be able to do this with an environment variable as well. Author: Michael Armbrust Closes #3893 from marmbrus/localHostnameEnv and squashes the following commits: 85045b6 [Michael Armbrust] Optionally read from SPARK_LOCAL_HOSTNAME --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4f1898a2db1..578f1a9cf486 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -701,7 +701,7 @@ private[spark] object Utils extends Logging { } } - private var customHostname: Option[String] = None + private var customHostname: Option[String] = sys.env.get("SPARK_LOCAL_HOSTNAME") /** * Allow setting a custom host name because when we run on Mesos we need to use the same From aff49a3ee1a91d730590f4a0b9ac485fd52dc8bf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 12 Jan 2015 12:15:34 -0800 Subject: [PATCH 127/215] SPARK-5172 [BUILD] spark-examples-***.jar shades a wrong Hadoop distribution In addition to the `hadoop-2.x` profiles in the parent POM, there is actually another set of profiles in `examples` that has to be activated differently to get the right Hadoop 1 vs 2 flavor of HBase. This wasn't actually used in making Hadoop 2 distributions, hence the problem. To reduce complexity, I suggest merging them with the parent POM profiles, which is possible now. You'll see this changes appears to update the HBase version, but actually, the default 0.94 version was not being used. HBase is only used in examples, and the examples POM always chose one profile or the other that updated the version to 0.98.x anyway. 
Author: Sean Owen Closes #3992 from srowen/SPARK-5172 and squashes the following commits: 17830d9 [Sean Owen] Control hbase hadoop1/2 flavor in the parent POM with existing hadoop-2.x profiles --- examples/pom.xml | 23 ----------------------- pom.xml | 5 ++++- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/examples/pom.xml b/examples/pom.xml index 002d4458c4b3..4b92147725f6 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -392,29 +392,6 @@ - - hbase-hadoop2 - - - hbase.profile - hadoop2 - - - - 0.98.7-hadoop2 - - - - hbase-hadoop1 - - - !hbase.profile - - - - 0.98.7-hadoop1 - - diff --git a/pom.xml b/pom.xml index 0403c008be4f..f4466e56c2a5 100644 --- a/pom.xml +++ b/pom.xml @@ -122,7 +122,7 @@ 1.0.4 2.4.1 ${hadoop.version} - 0.94.6 + 0.98.7-hadoop1 hbase 1.4.0 3.4.5 @@ -1466,6 +1466,7 @@ 2.2.0 2.5.0 + 0.98.7-hadoop2 hadoop2 @@ -1476,6 +1477,7 @@ 2.3.0 2.5.0 0.9.0 + 0.98.7-hadoop2 3.1.1 hadoop2 @@ -1487,6 +1489,7 @@ 2.4.0 2.5.0 0.9.0 + 0.98.7-hadoop2 3.1.1 hadoop2 From 3aed3051c0b6cd5f38d7db7d20fa7a1680bfde6f Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 12 Jan 2015 13:14:44 -0800 Subject: [PATCH 128/215] [SPARK-4999][Streaming] Change storeInBlockManager to false by default Currently WAL-backed block is read out from HDFS and put into BlockManger with storage level MEMORY_ONLY_SER by default, since WAL-backed block is already materialized in HDFS with fault-tolerance, no need to put into BlockManger again by default. Author: jerryshao Closes #3906 from jerryshao/SPARK-4999 and squashes the following commits: b95f95e [jerryshao] Change storeInBlockManager to false by default --- .../apache/spark/streaming/dstream/ReceiverInputDStream.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index c834744631e0..afd3c4bc4c4f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -86,7 +86,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont }.toArray // Since storeInBlockManager = false, the storage level does not matter. new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, - blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER) + blockIds, logSegments, storeInBlockManager = false, StorageLevel.MEMORY_ONLY_SER) } else { new BlockRDD[T](ssc.sc, blockIds) } From 5d9fa550820543ee1b0ce82997917745973a5d65 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 12 Jan 2015 15:19:09 -0800 Subject: [PATCH 129/215] [SPARK-5049][SQL] Fix ordering of partition columns in ParquetTableScan Followup to #3870. Props to rahulaggarwalguavus for identifying the issue. 
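For reference, a minimal sketch of the behavior this fix targets, assuming a local HiveContext and a pre-registered partitioned Parquet table named `partitioned_parquet` (as in the test suite touched below) with an integer partition column `p` and a data column `stringField`:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

// Assumed setup: local context plus an existing partitioned Parquet table
// "partitioned_parquet" with partition column `p` and column `stringField`.
val sc = new SparkContext(new SparkConf().setAppName("spark-5049-sketch").setMaster("local[*]"))
val hiveContext = new HiveContext(sc)

// The partitioning column should be filled in correctly no matter where it
// appears in the projection; before this fix the second query could return
// columns in the wrong order because partition values were always prepended
// through a JoinedRow instead of being placed at the requested ordinals.
hiveContext.sql("SELECT p, stringField FROM partitioned_parquet WHERE p = 1").collect()
hiveContext.sql("SELECT stringField, p FROM partitioned_parquet WHERE p = 1").collect()
```
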
Author: Michael Armbrust Closes #3990 from marmbrus/SPARK-5049 and squashes the following commits: dd03e4e [Michael Armbrust] Fill in the partition values of parquet scans instead of using JoinedRow --- .../spark/sql/parquet/ParquetRelation.scala | 4 +- .../sql/parquet/ParquetTableOperations.scala | 43 +++++++++++-------- .../spark/sql/parquet/parquetSuites.scala | 12 ++++++ 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b237a07c72d0..2835dc3408b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -28,7 +28,7 @@ import parquet.schema.MessageType import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} /** @@ -67,6 +67,8 @@ private[sql] case class ParquetRelation( conf, sqlContext.isParquetBinaryAsString) + lazy val attributeMap = AttributeMap(output.map(o => o -> o)) + override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 96bace1769f7..f5487740d3af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -64,18 +64,17 @@ case class ParquetTableScan( // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes // by exprId. note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 - val normalOutput = - attributes - .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId)) - .flatMap(a => relation.output.find(o => o.exprId == a.exprId)) + val output = attributes.map(relation.attributeMap) - val partOutput = - attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId)) + // A mapping of ordinals partitionRow -> finalOutput. 
+ val requestedPartitionOrdinals = { + val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex) - def output = partOutput ++ normalOutput - - assert(normalOutput.size + partOutput.size == attributes.size, - s"$normalOutput + $partOutput != $attributes, ${relation.output}") + attributes.zipWithIndex.flatMap { + case (attribute, finalOrdinal) => + partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal) + } + }.toArray override def execute(): RDD[Row] = { import parquet.filter2.compat.FilterCompat.FilterPredicateCompat @@ -97,7 +96,7 @@ case class ParquetTableScan( // Store both requested and original schema in `Configuration` conf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(normalOutput)) + ParquetTypesConverter.convertToString(output)) conf.set( RowWriteSupport.SPARK_ROW_SCHEMA, ParquetTypesConverter.convertToString(relation.output)) @@ -125,7 +124,7 @@ case class ParquetTableScan( classOf[Row], conf) - if (partOutput.nonEmpty) { + if (requestedPartitionOrdinals.nonEmpty) { baseRDD.mapPartitionsWithInputSplit { case (split, iter) => val partValue = "([^=]+)=([^=]+)".r val partValues = @@ -138,15 +137,25 @@ case class ParquetTableScan( case _ => None }.toMap + // Convert the partitioning attributes into the correct types val partitionRowValues = - partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) + relation.partitioningAttributes + .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) new Iterator[Row] { - private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null) - def hasNext = iter.hasNext - - def next() = joinedRow.withRight(iter.next()._2) + def next() = { + val row = iter.next()._2.asInstanceOf[SpecificMutableRow] + + // Parquet will leave partitioning columns empty, so we fill them in here. + var i = 0 + while (i < requestedPartitionOrdinals.size) { + row(requestedPartitionOrdinals(i)._2) = + partitionRowValues(requestedPartitionOrdinals(i)._1) + i += 1 + } + row + } } } } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index fc0e42c201d5..8bbb7f2fdbf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -174,6 +174,18 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll } Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)((1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(("part-1", 1)) + ) + } + test(s"project the partitioning column $table") { checkAnswer( sql(s"SELECT p, count(*) FROM $table group by p"), From 1e42e96ece9e35ceed9ddebef66d589016878b56 Mon Sep 17 00:00:00 2001 From: Gabe Mulley Date: Mon, 12 Jan 2015 21:44:51 -0800 Subject: [PATCH 130/215] [SPARK-5138][SQL] Ensure schema can be inferred from a namedtuple When attempting to infer the schema of an RDD that contains namedtuples, pyspark fails to identify the records as namedtuples, resulting in it raising an error. 
Example: ```python from pyspark import SparkContext from pyspark.sql import SQLContext from collections import namedtuple import os sc = SparkContext() rdd = sc.textFile(os.path.join(os.getenv('SPARK_HOME'), 'README.md')) TextLine = namedtuple('TextLine', 'line length') tuple_rdd = rdd.map(lambda l: TextLine(line=l, length=len(l))) tuple_rdd.take(5) # This works sqlc = SQLContext(sc) # The following line raises an error schema_rdd = sqlc.inferSchema(tuple_rdd) ``` The error raised is: ``` File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 107, in main process() File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 98, in process serializer.dump_stream(func(split_index, iterator), outfile) File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/serializers.py", line 227, in dump_stream vs = list(itertools.islice(iterator, batch)) File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/rdd.py", line 1107, in takeUpToNumLeft yield next(iterator) File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/sql.py", line 816, in convert_struct raise ValueError("unexpected tuple: %s" % obj) TypeError: not all arguments converted during string formatting ``` Author: Gabe Mulley Closes #3978 from mulby/inferschema-namedtuple and squashes the following commits: 98c61cc [Gabe Mulley] Ensure exception message is populated correctly 375d96b [Gabe Mulley] Ensure schema can be inferred from a namedtuple --- python/pyspark/sql.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0e8b398fc6b9..014ac1791c84 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -807,14 +807,14 @@ def convert_struct(obj): return if isinstance(obj, tuple): - if hasattr(obj, "fields"): - d = dict(zip(obj.fields, obj)) - if hasattr(obj, "__FIELDS__"): + if hasattr(obj, "_fields"): + d = dict(zip(obj._fields, obj)) + elif hasattr(obj, "__FIELDS__"): d = dict(zip(obj.__FIELDS__, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): d = dict(obj) else: - raise ValueError("unexpected tuple: %s" % obj) + raise ValueError("unexpected tuple: %s" % str(obj)) elif isinstance(obj, dict): d = obj @@ -1327,6 +1327,16 @@ def inferSchema(self, rdd, samplingRatio=None): >>> srdd = sqlCtx.inferSchema(nestedRdd2) >>> srdd.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] + + >>> from collections import namedtuple + >>> CustomRow = namedtuple('CustomRow', 'field1 field2') + >>> rdd = sc.parallelize( + ... [CustomRow(field1=1, field2="row1"), + ... CustomRow(field1=2, field2="row2"), + ... CustomRow(field1=3, field2="row3")]) + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.collect()[0] + Row(field1=1, field2=u'row1') """ if isinstance(rdd, SchemaRDD): From f7741a9a72fef23c46f0ad9e1bd16b150967d816 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Tue, 13 Jan 2015 09:28:21 -0800 Subject: [PATCH 131/215] [SPARK-5006][Deploy]spark.port.maxRetries doesn't work https://issues.apache.org/jira/browse/SPARK-5006 I think the issue is produced in https://github.com/apache/spark/pull/1777. Not digging mesos's backend yet. Maybe should add same logic either. 
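As a hedged illustration (not part of the patch itself), once the retry count is read from a SparkConf the setting can be applied per application, for example:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: with this change Utils.startServiceOnPort receives the SparkConf,
// so a value set here (or via `--conf spark.port.maxRetries=32` on spark-submit)
// is honored when binding ports such as the UI or the file server.
val conf = new SparkConf()
  .setAppName("port-retries-sketch")
  .setMaster("local[*]")
  .set("spark.port.maxRetries", "32") // try the requested port plus up to 32 successors
val sc = new SparkContext(conf)
sc.stop()
```
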
Author: WangTaoTheTonic Author: WangTao Closes #3841 from WangTaoTheTonic/SPARK-5006 and squashes the following commits: 8cdf96d [WangTao] indent thing 2d86d65 [WangTaoTheTonic] fix line length 7cdfd98 [WangTaoTheTonic] fit for new HttpServer constructor 61a370d [WangTaoTheTonic] some minor fixes bc6e1ec [WangTaoTheTonic] rebase 67bcb46 [WangTaoTheTonic] put conf at 3rd position, modify suite class, add comments f450cd1 [WangTaoTheTonic] startServiceOnPort will use a SparkConf arg 29b751b [WangTaoTheTonic] rebase as ExecutorRunnableUtil changed to ExecutorRunnable 396c226 [WangTaoTheTonic] make the grammar more like scala 191face [WangTaoTheTonic] invalid value name 62ec336 [WangTaoTheTonic] spark.port.maxRetries doesn't work --- .../org/apache/spark/HttpFileServer.scala | 3 ++- .../scala/org/apache/spark/HttpServer.scala | 3 ++- .../scala/org/apache/spark/SparkConf.scala | 6 +++-- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/broadcast/HttpBroadcast.scala | 3 ++- .../spark/network/nio/ConnectionManager.scala | 2 +- .../org/apache/spark/ui/JettyUtils.scala | 2 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 23 +++++++++---------- .../streaming/flume/FlumeStreamSuite.scala | 2 +- .../streaming/mqtt/MQTTStreamSuite.scala | 3 ++- .../org/apache/spark/repl/SparkIMain.scala | 2 +- .../scala/org/apache/spark/repl/Main.scala | 2 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 9 +++----- 14 files changed, 33 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index edc3889c9ae5..677c5e0f89d7 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -24,6 +24,7 @@ import com.google.common.io.Files import org.apache.spark.util.Utils private[spark] class HttpFileServer( + conf: SparkConf, securityManager: SecurityManager, requestedPort: Int = 0) extends Logging { @@ -41,7 +42,7 @@ private[spark] class HttpFileServer( fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server") + httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server") httpServer.start() serverUri = httpServer.uri logDebug("HTTP file server started at: " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 912558d0cab7..fa22787ce7ea 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -42,6 +42,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * around a Jetty server. 
*/ private[spark] class HttpServer( + conf: SparkConf, resourceBase: File, securityManager: SecurityManager, requestedPort: Int = 0, @@ -57,7 +58,7 @@ private[spark] class HttpServer( } else { logInfo("Starting HTTP Server") val (actualServer, actualPort) = - Utils.startServiceOnPort[Server](requestedPort, doStart, serverName) + Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName) server = actualServer port = actualPort } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c14764f77398..a0ce107f43b1 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -370,7 +370,9 @@ private[spark] object SparkConf { } /** - * Return whether the given config is a Spark port config. + * Return true if the given config matches either `spark.*.port` or `spark.port.*`. */ - def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port") + def isSparkPortConf(name: String): Boolean = { + (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") + } } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 43436a169700..4d418037bd33 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -312,7 +312,7 @@ object SparkEnv extends Logging { val httpFileServer = if (isDriver) { val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(securityManager, fileServerPort) + val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() conf.set("spark.fileserver.uri", server.serverUri) server diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 31f0a462f84d..31d6958c403b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -153,7 +153,8 @@ private[broadcast] object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) val broadcastPort = conf.getInt("spark.broadcast.port", 0) - server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") + server = + new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 3340fca08014..03c4137ca0a8 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -174,7 +174,7 @@ private[nio] class ConnectionManager( serverChannel.socket.bind(new InetSocketAddress(port)) (serverChannel, serverChannel.socket.getLocalPort) } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, name) + Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) serverChannel.register(selector, SelectionKey.OP_ACCEPT) val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) diff --git 
a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 2a27d49d2de0..88fed833f922 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -201,7 +201,7 @@ private[spark] object JettyUtils extends Logging { } } - val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName) + val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index db2531dc171f..4c9b1e3c46f0 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -53,7 +53,7 @@ private[spark] object AkkaUtils extends Logging { val startService: Int => (ActorSystem, Int) = { actualPort => doCreateActorSystem(name, host, actualPort, conf, securityManager) } - Utils.startServiceOnPort(port, startService, name) + Utils.startServiceOnPort(port, startService, conf, name) } private def doCreateActorSystem( diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 578f1a9cf486..2c04e4ddfbcb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1690,17 +1690,15 @@ private[spark] object Utils extends Logging { } /** - * Default maximum number of retries when binding to a port before giving up. + * Maximum number of retries when binding to a port before giving up. */ - val portMaxRetries: Int = { - if (sys.props.contains("spark.testing")) { + def portMaxRetries(conf: SparkConf): Int = { + val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt) + if (conf.contains("spark.testing")) { // Set a higher number of retries for tests... - sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100) + maxRetries.getOrElse(100) } else { - Option(SparkEnv.get) - .flatMap(_.conf.getOption("spark.port.maxRetries")) - .map(_.toInt) - .getOrElse(16) + maxRetries.getOrElse(16) } } @@ -1709,17 +1707,18 @@ private[spark] object Utils extends Logging { * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). * * @param startPort The initial port to start the service on. - * @param maxRetries Maximum number of retries to attempt. - * A value of 3 means attempting ports n, n+1, n+2, and n+3, for example. * @param startService Function to start service on a given port. * This is expected to throw java.net.BindException on port collision. + * @param conf A SparkConf used to get the maximum number of retries when binding to a port. + * @param serviceName Name of the service. 
*/ def startServiceOnPort[T]( startPort: Int, startService: Int => (T, Int), - serviceName: String = "", - maxRetries: Int = portMaxRetries): (T, Int) = { + conf: SparkConf, + serviceName: String = ""): (T, Int) = { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" + val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port val tryPort = if (startPort == 0) { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 13943ed5442b..f333e3891b5f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -80,7 +80,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) - })._2 + }, conf)._2 } /** Setup and start the streaming context */ diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 30727dfa6443..fe53a29cba0c 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.SparkConf import org.apache.spark.util.Utils class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { @@ -106,7 +107,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) - })._2 + }, new SparkConf())._2 } def publishData(data: String): Unit = { diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 646c68e60c2e..b646f0b6f086 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -106,7 +106,7 @@ import org.apache.spark.util.Utils val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ val classServerPort = conf.getInt("spark.replClassServer.port", 0) - val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") + val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 5e93a7199507..69e44d4f916e 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -32,7 +32,7 @@ object Main extends Logging { val s = new 
Settings() s.processArguments(List("-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) - val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var interp = new SparkILoop // this is a public var because tests reset it. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ebf5616e8d30..c537da9f6755 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -148,12 +148,9 @@ class ExecutorRunnable( // registers with the Scheduler and transfers the spark configs. Since the Executor backend // uses Akka to connect to the scheduler, the akka settings are needed as well as the // authentication settings. - sparkConf.getAll. - filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } - - sparkConf.getAkkaConf. - foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } + sparkConf.getAll + .filter { case (k, v) => SparkConf.isExecutorStartupConf(k) } + .foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. From 9dea64e53ad8df8a3160c0f4010811af1e73dd6f Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Tue, 13 Jan 2015 09:43:48 -0800 Subject: [PATCH 132/215] [SPARK-4697][YARN]System properties should override environment variables I found some arguments in yarn module take environment variables before system properties while the latter override the former in core module. Author: WangTaoTheTonic Author: WangTao Closes #3557 from WangTaoTheTonic/SPARK4697 and squashes the following commits: 836b9ef [WangTaoTheTonic] fix type mismatch e3e486a [WangTaoTheTonic] remove the comma 1262d57 [WangTaoTheTonic] handle spark.app.name and SPARK_YARN_APP_NAME in SparkSubmitArguments bee9447 [WangTaoTheTonic] wrong brace 81833bb [WangTaoTheTonic] rebase 40934b4 [WangTaoTheTonic] just switch blocks 5f43f45 [WangTao] System property can override environment variable --- .../spark/deploy/SparkSubmitArguments.scala | 5 +++++ .../spark/deploy/yarn/ClientArguments.scala | 4 ++-- .../cluster/YarnClientSchedulerBackend.scala | 19 +++++++++++-------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f14ef4d29938..47059b08a397 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -149,6 +149,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St // Global defaults. These should be keep to minimum to avoid confusing behavior. 
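A hedged sketch of the intended precedence after this change, using the `spark.yarn.queue` / `SPARK_YARN_QUEUE` pair from the YARN backend below as the example:

```scala
import org.apache.spark.SparkConf

// Sketch only: assume SPARK_YARN_QUEUE is exported in the environment.
// After this patch the YARN client backend consults the Spark property before
// falling back to the environment variable, so `--conf spark.yarn.queue=production`
// wins over SPARK_YARN_QUEUE.
val conf = new SparkConf().set("spark.yarn.queue", "production")
val queue = conf.getOption("spark.yarn.queue")
  .orElse(sys.env.get("SPARK_YARN_QUEUE")) // fallback, mirroring the new ordering
  .getOrElse("default")
// queue == "production" here, regardless of the environment variable
```
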
master = Option(master).getOrElse("local[*]") + // In YARN mode, app name can be set via SPARK_YARN_APP_NAME (see SPARK-5222) + if (master.startsWith("yarn")) { + name = Option(name).orElse(env.get("SPARK_YARN_APP_NAME")).orNull + } + // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index fdbf9f8eed02..461a9ccd3c21 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -64,12 +64,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://, // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051). files = Option(files) - .orElse(sys.env.get("SPARK_YARN_DIST_FILES")) .orElse(sparkConf.getOption("spark.yarn.dist.files").map(p => Utils.resolveURIs(p))) + .orElse(sys.env.get("SPARK_YARN_DIST_FILES")) .orNull archives = Option(archives) - .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) + .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) .orNull // If dynamic allocation is enabled, start at the max number of executors if (isDynamicAllocationEnabled) { diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f99291553b7b..690f927e938c 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -74,8 +74,7 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--name", "SPARK_YARN_APP_NAME", "spark.app.name") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( @@ -86,18 +85,22 @@ private[spark] class YarnClientSchedulerBackend( // Do the same for deprecated properties: property -> suggestion val deprecatedProps = Map("spark.master.memory" -> "--driver-memory through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => - if (System.getenv(envVar) != null) { - extraArgs += (optionName, System.getenv(envVar)) - if (deprecatedEnvVars.contains(envVar)) { - logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") - } - } else if (sc.getConf.contains(sparkProp)) { + if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) if (deprecatedProps.contains(sparkProp)) { logWarning(s"NOTE: $sparkProp is deprecated. Use ${deprecatedProps(sparkProp)} instead.") } + } else if (System.getenv(envVar) != null) { + extraArgs += (optionName, System.getenv(envVar)) + if (deprecatedEnvVars.contains(envVar)) { + logWarning(s"NOTE: $envVar is deprecated. 
Use ${deprecatedEnvVars(envVar)} instead.") + } } } + // The app name is a special case because "spark.app.name" is required of all applications. + // As a result, the corresponding "SPARK_YARN_APP_NAME" is already handled preemptively in + // SparkSubmitArguments if "spark.app.name" is not explicitly set by the user. (SPARK-5222) + sc.getConf.getOption("spark.app.name").foreach(v => extraArgs += ("--name", v)) extraArgs } From 39e333ec4350ddafe29ee0958c37eec07bec85df Mon Sep 17 00:00:00 2001 From: uncleGen Date: Tue, 13 Jan 2015 10:07:19 -0800 Subject: [PATCH 133/215] [SPARK-5131][Streaming][DOC]: There is a discrepancy in WAL implementation and configuration doc. There is a discrepancy in WAL implementation and configuration doc. Author: uncleGen Closes #3930 from uncleGen/master-clean-doc and squashes the following commits: 3a4245f [uncleGen] doc typo 8e407d3 [uncleGen] doc typo --- docs/configuration.md | 2 +- docs/streaming-programming-guide.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index f292bfbb7dcd..673cdb371a51 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1228,7 +1228,7 @@ Apart from these, the following properties are also available, and may be useful
    spark.streaming.receiver.writeAheadLogs.enablespark.streaming.receiver.writeAheadLog.enable false Enable write ahead logs for receivers. All the input data received through receivers diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 01450efe35e5..e37a2bb37b9a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1574,7 +1574,7 @@ To run a Spark Streaming applications, you need to have the following. recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) - `spark.streaming.receiver.writeAheadLogs.enable` to `true`. However, these stronger semantics may + `spark.streaming.receiver.writeAheadLog.enable` to `true`. However, these stronger semantics may come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the From 8ead999fd627b12837fb2f082a0e76e9d121d269 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 13 Jan 2015 12:50:31 -0800 Subject: [PATCH 134/215] [SPARK-5223] [MLlib] [PySpark] fix MapConverter and ListConverter in MLlib It will introduce problems if the object in dict/list/tuple can not support by py4j, such as Vector. Also, pickle may have better performance for larger object (less RPC). In some cases that the object in dict/list can not be pickled (such as JavaObject), we should still use MapConvert/ListConvert. This PR should be ported into branch-1.2 Author: Davies Liu Closes #4023 from davies/listconvert and squashes the following commits: 55d4ab2 [Davies Liu] fix MapConverter and ListConverter in MLlib --- python/pyspark/mllib/common.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 33c49e239990..3c5ee66cd8b6 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -18,7 +18,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList +from py4j.java_collections import ListConverter, JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -70,9 +70,7 @@ def _py2java(sc, obj): obj = _to_java_object_rdd(obj) elif isinstance(obj, SparkContext): obj = obj._jsc - elif isinstance(obj, dict): - obj = MapConverter().convert(obj, sc._gateway._gateway_client) - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)): obj = ListConverter().convert(obj, sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass From 6463e0b9e8067cce70602c5c9006a2546856a9d6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 13 Jan 2015 13:01:27 -0800 Subject: [PATCH 135/215] [SPARK-4912][SQL] Persistent tables for the Spark SQL data sources api With changes in this PR, users can persist metadata of tables created based on the data source API in metastore through DDLs. Author: Yin Huai Author: Michael Armbrust Closes #3960 from yhuai/persistantTablesWithSchema2 and squashes the following commits: 069c235 [Yin Huai] Make exception messages user friendly. 
c07cbc6 [Yin Huai] Get the location of test file in a correct way. 4456e98 [Yin Huai] Test data. 5315dfc [Yin Huai] rxin's comments. 7fc4b56 [Yin Huai] Add DDLStrategy and HiveDDLStrategy to plan DDLs based on the data source API. aeaf4b3 [Yin Huai] Add comments. 06f9b0c [Yin Huai] Revert unnecessary changes. feb88aa [Yin Huai] Merge remote-tracking branch 'apache/master' into persistantTablesWithSchema2 172db80 [Yin Huai] Fix unit test. 49bf1ac [Yin Huai] Unit tests. 8f8f1a1 [Yin Huai] [SPARK-4574][SQL] Adding support for defining schema in foreign DDL commands. #3431 f47fda1 [Yin Huai] Unit tests. 2b59723 [Michael Armbrust] Set external when creating tables c00bb1b [Michael Armbrust] Don't use reflection to read options 1ea6e7b [Michael Armbrust] Don't fail when trying to uncache a table that doesn't exist 6edc710 [Michael Armbrust] Add tests. d7da491 [Michael Armbrust] First draft of persistent tables. --- .../org/apache/spark/sql/SQLContext.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 14 + .../apache/spark/sql/execution/commands.scala | 3 +- .../org/apache/spark/sql/sources/ddl.scala | 53 ++-- .../apache/spark/sql/sources/interfaces.scala | 17 +- .../spark/sql/sources/TableScanSuite.scala | 30 +++ .../apache/spark/sql/hive/HiveContext.scala | 12 + .../spark/sql/hive/HiveMetastoreCatalog.scala | 79 +++++- .../spark/sql/hive/HiveStrategies.scala | 11 + .../org/apache/spark/sql/hive/TestHive.scala | 1 + .../spark/sql/hive/execution/commands.scala | 21 ++ sql/hive/src/test/resources/sample.json | 2 + .../sql/hive/MetastoreDataSourcesSuite.scala | 244 ++++++++++++++++++ .../sql/hive/execution/HiveQuerySuite.scala | 1 - 14 files changed, 461 insertions(+), 28 deletions(-) create mode 100644 sql/hive/src/test/resources/sample.json create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6c575dd727b4..e7021cc3366c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -330,6 +330,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def strategies: Seq[Strategy] = extraStrategies ++ ( DataSourceStrategy :: + DDLStrategy :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 99b6611d3bbc..d91b1fbc6983 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.sources.{CreateTempTableUsing, CreateTableUsing} import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -310,4 +311,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } } + + object DDLStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, options) => + ExecutedCommand( + CreateTempTableUsing(tableName, userSpecifiedSchema, provider, options)) :: Nil + + case CreateTableUsing(tableName, userSpecifiedSchema, 
provider, false, options) => + sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") + + case _ => Nil + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 0d765c4c92f8..df8e61615104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -177,7 +177,6 @@ case class DescribeCommand( override val output: Seq[Attribute]) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - Row("# Registered as a temporary table", null, null) +: - child.output.map(field => Row(field.name, field.dataType.toString, null)) + child.output.map(field => Row(field.name, field.dataType.toString, null)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index fe2c4d8436b2..f8741e008209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -92,21 +92,21 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected lazy val ddl: Parser[LogicalPlan] = createTable /** - * `CREATE TEMPORARY TABLE avroTable + * `CREATE [TEMPORARY] TABLE avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...) + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` */ protected lazy val createTable: Parser[LogicalPlan] = ( - CREATE ~ TEMPORARY ~ TABLE ~> ident + (CREATE ~> TEMPORARY.? <~ TABLE) ~ ident ~ (tableCols).? 
~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { - case tableName ~ columns ~ provider ~ opts => + case temp ~ tableName ~ columns ~ provider ~ opts => val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing(tableName, userSpecifiedSchema, provider, opts) + CreateTableUsing(tableName, userSpecifiedSchema, provider, temp.isDefined, opts) } ) @@ -175,13 +175,12 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi primitiveType } -private[sql] case class CreateTableUsing( - tableName: String, - userSpecifiedSchema: Option[StructType], - provider: String, - options: Map[String, String]) extends RunnableCommand { - - def run(sqlContext: SQLContext) = { +object ResolvedDataSource { + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + provider: String, + options: Map[String, String]): ResolvedDataSource = { val loader = Utils.getContextOrSparkClassLoader val clazz: Class[_] = try loader.loadClass(provider) catch { case cnf: java.lang.ClassNotFoundException => @@ -199,22 +198,44 @@ private[sql] case class CreateTableUsing( .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) case _ => - sys.error(s"${clazz.getCanonicalName} should extend SchemaRelationProvider.") + sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") } } case None => { clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.RelationProvider => + case dataSource: org.apache.spark.sql.sources.RelationProvider => dataSource .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] .createRelation(sqlContext, new CaseInsensitiveMap(options)) case _ => - sys.error(s"${clazz.getCanonicalName} should extend RelationProvider.") + sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") } } } - sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) + new ResolvedDataSource(clazz, relation) + } +} + +private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) + +private[sql] case class CreateTableUsing( + tableName: String, + userSpecifiedSchema: Option[StructType], + provider: String, + temporary: Boolean, + options: Map[String, String]) extends Command + +private [sql] case class CreateTempTableUsing( + tableName: String, + userSpecifiedSchema: Option[StructType], + provider: String, + options: Map[String, String]) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) + + sqlContext.baseRelationToSchemaRDD(resolved.relation).registerTempTable(tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 990f7e0e74bc..2a7be23e37c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} /** * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When - * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to - * pass in the parameters specified by a user. 
+ * Spark SQL is given a DDL operation with a USING clause specified (to specify the implemented + * RelationProvider), this interface is used to pass in the parameters specified by a user. * * Users may specify the fully qualified class name of a given data source. When that class is * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for @@ -46,10 +46,10 @@ trait RelationProvider { /** * ::DeveloperApi:: - * Implemented by objects that produce relations for a specific kind of data source. When - * Spark SQL is given a DDL operation with - * 1. USING clause: to specify the implemented SchemaRelationProvider - * 2. User defined schema: users can define schema optionally when create table + * Implemented by objects that produce relations for a specific kind of data source + * with a given schema. When Spark SQL is given a DDL operation with a USING clause specified ( + * to specify the implemented SchemaRelationProvider) and a user defined schema, this interface + * is used to pass in the parameters specified by a user. * * Users may specify the fully qualified class name of a given data source. When that class is * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for @@ -57,6 +57,11 @@ trait RelationProvider { * data source 'org.apache.spark.sql.json.DefaultSource' * * A new instance of this class with be instantiated each time a DDL call is made. + * + * The difference between a [[RelationProvider]] and a [[SchemaRelationProvider]] is that + * users need to provide a schema when using a SchemaRelationProvider. + * A relation provider can inherits both [[RelationProvider]] and [[SchemaRelationProvider]] + * if it can support both schema inference and user-specified schemas. */ @DeveloperApi trait SchemaRelationProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 605190f5ae6a..a1d2468b2573 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -314,4 +314,34 @@ class TableScanSuite extends DataSourceTest { sql("SELECT * FROM oneToTenDef"), (1 to 10).map(Row(_)).toSeq) } + + test("exceptions") { + // Make sure we do throw correct exception when users use a relation provider that + // only implements the RelationProvier or the SchemaRelationProvider. 
+ val schemaNotAllowed = intercept[Exception] { + sql( + """ + |CREATE TEMPORARY TABLE relationProvierWithSchema (i int) + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + assert(schemaNotAllowed.getMessage.contains("does not allow user-specified schemas")) + + val schemaNeeded = intercept[Exception] { + sql( + """ + |CREATE TEMPORARY TABLE schemaRelationProvierWithoutSchema + |USING org.apache.spark.sql.sources.AllDataTypesScanSource + |OPTIONS ( + | From '1', + | To '10' + |) + """.stripMargin) + } + assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 02eac43b2103..09ff4cc5ab43 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -115,6 +115,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } + def refreshTable(tableName: String): Unit = { + // TODO: Database support... + catalog.refreshTable("default", tableName) + } + + protected[hive] def invalidateTable(tableName: String): Unit = { + // TODO: Database support... + catalog.invalidateTable("default", tableName) + } + /** * Analyzes the given table in the current database to generate statistics, which will be * used in query optimizations. @@ -340,6 +350,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def strategies: Seq[Strategy] = extraStrategies ++ Seq( DataSourceStrategy, HiveCommandStrategy(self), + HiveDDLStrategy, + DDLStrategy, TakeOrdered, ParquetOperations, InMemoryScans, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c25288e00012..daeabb6c8bab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} +import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} + import org.apache.hadoop.util.ReflectionUtils import org.apache.hadoop.hive.metastore.TableType -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, FieldSchema} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table, HiveException} import org.apache.hadoop.hive.ql.metadata.InvalidTableException import org.apache.hadoop.hive.ql.plan.CreateTableDesc @@ -39,6 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.util.Utils /* Implicit conversions */ @@ -50,8 +52,76 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with /** Connection to hive metastore. Usages should lock on `this`. 
*/ protected[hive] val client = Hive.get(hive.hiveconf) + // TODO: Use this everywhere instead of tuples or databaseName, tableName,. + /** A fully qualified identifier for a table (i.e., database.tableName) */ + case class QualifiedTableName(database: String, name: String) { + def toLowerCase = QualifiedTableName(database.toLowerCase, name.toLowerCase) + } + + /** A cache of Spark SQL data source tables that have been accessed. */ + protected[hive] val cachedDataSourceTables: LoadingCache[QualifiedTableName, LogicalPlan] = { + val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { + override def load(in: QualifiedTableName): LogicalPlan = { + logDebug(s"Creating new cached data source for $in") + val table = client.getTable(in.database, in.name) + val schemaString = table.getProperty("spark.sql.sources.schema") + val userSpecifiedSchema = + if (schemaString == null) { + None + } else { + Some(DataType.fromJson(schemaString).asInstanceOf[StructType]) + } + // It does not appear that the ql client for the metastore has a way to enumerate all the + // SerDe properties directly... + val options = table.getTTable.getSd.getSerdeInfo.getParameters.toMap + + val resolvedRelation = + ResolvedDataSource( + hive, + userSpecifiedSchema, + table.getProperty("spark.sql.sources.provider"), + options) + + LogicalRelation(resolvedRelation.relation) + } + } + + CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) + } + + def refreshTable(databaseName: String, tableName: String): Unit = { + cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase) + } + + def invalidateTable(databaseName: String, tableName: String): Unit = { + cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) + } + val caseSensitive: Boolean = false + def createDataSourceTable( + tableName: String, + userSpecifiedSchema: Option[StructType], + provider: String, + options: Map[String, String]) = { + val (dbName, tblName) = processDatabaseAndTableName("default", tableName) + val tbl = new Table(dbName, tblName) + + tbl.setProperty("spark.sql.sources.provider", provider) + if (userSpecifiedSchema.isDefined) { + tbl.setProperty("spark.sql.sources.schema", userSpecifiedSchema.get.json) + } + options.foreach { case (key, value) => tbl.setSerdeParam(key, value) } + + tbl.setProperty("EXTERNAL", "TRUE") + tbl.setTableType(TableType.EXTERNAL_TABLE) + + // create the table + synchronized { + client.createTable(tbl, false) + } + } + def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( @@ -72,7 +142,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last val table = client.getTable(databaseName, tblName) - if (table.isView) { + + if (table.getProperty("spark.sql.sources.provider") != null) { + cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) + } else if (table.isView) { // if the unresolved relation is from hive view // parse the text into logic node. 
HiveQl.createPlanForView(table, alias) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index c439b9ebfe10..cdff82e3d04d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.CreateTableUsing import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} import scala.collection.JavaConversions._ @@ -208,6 +209,16 @@ private[hive] trait HiveStrategies { } } + object HiveDDLStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, options) => + ExecutedCommand( + CreateMetastoreDataSource(tableName, userSpecifiedSchema, provider, options)) :: Nil + + case _ => Nil + } + } + case class HiveCommandStrategy(context: HiveContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case describe: DescribeCommand => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 1358a0eccb35..31c7ce96398e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -395,6 +395,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { clearCache() loadedTables.clear() + catalog.cachedDataSourceTables.invalidateAll() catalog.client.getAllTables("default").foreach { t => logDebug(s"Deleting table $t") val table = catalog.client.getTable("default", t) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 6b733a280e6d..e70cdeaad4c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.SQLContext @@ -52,6 +53,12 @@ case class DropTable( override def run(sqlContext: SQLContext) = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" + try { + hiveContext.tryUncacheQuery(hiveContext.table(tableName)) + } catch { + case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => + } + hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(Seq(tableName)) Seq.empty[Row] @@ -85,3 +92,17 @@ case class AddFile(path: String) extends RunnableCommand { Seq.empty[Row] } } + +case class CreateMetastoreDataSource( + tableName: String, + userSpecifiedSchema: Option[StructType], + provider: String, + options: Map[String, String]) extends RunnableCommand { + + override def run(sqlContext: SQLContext) = { + val hiveContext = sqlContext.asInstanceOf[HiveContext] + 
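// Persist the definition in the metastore: the provider class and the options become
// table/SerDe properties of an EXTERNAL Hive table, and any user-specified schema is
// stored as JSON under the spark.sql.sources.schema property.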
hiveContext.catalog.createDataSourceTable(tableName, userSpecifiedSchema, provider, options) + + Seq.empty[Row] + } +} diff --git a/sql/hive/src/test/resources/sample.json b/sql/hive/src/test/resources/sample.json new file mode 100644 index 000000000000..a2c2ffd5e033 --- /dev/null +++ b/sql/hive/src/test/resources/sample.json @@ -0,0 +1,2 @@ +{"a" : "2" ,"b" : "blah", "c_!@(3)":1} +{"" : {"d!" : [4, 5], "=" : [{"Dd2": null}, {"Dd2" : true}]}} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala new file mode 100644 index 000000000000..ec9ebb4a775a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.scalatest.BeforeAndAfterEach + +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql._ +import org.apache.spark.util.Utils + +/* Implicits */ +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * Tests for persisting tables created though the data sources API into the metastore. + */ +class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { + override def afterEach(): Unit = { + reset() + } + + val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + + test ("persistent JSON table") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + jsonFile(filePath).collect().toSeq) + } + + test ("persistent JSON table with a user specified schema") { + sql( + s""" + |CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + jsonFile(filePath).registerTempTable("expectedJsonTable") + + checkAnswer( + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable").collect().toSeq) + } + + test ("persistent JSON table with a user specified schema with a subset of fields") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. 
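// The declared subset becomes the table's schema (compared against expectedSchema
// below); attributes present in the JSON data but not declared are simply not exposed.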
+ sql( + s""" + |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + val innerStruct = StructType( + StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))) :: Nil) + val expectedSchema = StructType( + StructField("", innerStruct, true) :: + StructField("b", StringType, true) :: Nil) + + assert(expectedSchema == table("jsonTable").schema) + + jsonFile(filePath).registerTempTable("expectedJsonTable") + + checkAnswer( + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable").collect().toSeq) + } + + test("resolve shortened provider names") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + jsonFile(filePath).collect().toSeq) + } + + test("drop table") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + jsonFile(filePath).collect().toSeq) + + sql("DROP TABLE jsonTable") + + intercept[Exception] { + sql("SELECT * FROM jsonTable").collect() + } + } + + test("check change without refresh") { + val tempDir = File.createTempFile("sparksql", "json") + tempDir.delete() + sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + ("a", "b") :: Nil) + + FileUtils.deleteDirectory(tempDir) + sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + // Schema is cached so the new column does not show. The updated values in existing columns + // will show. + checkAnswer( + sql("SELECT * FROM jsonTable"), + ("a1", "b1") :: Nil) + + refreshTable("jsonTable") + + // Check that the refresh worked + checkAnswer( + sql("SELECT * FROM jsonTable"), + ("a1", "b1", "c1") :: Nil) + FileUtils.deleteDirectory(tempDir) + } + + test("drop, change, recreate") { + val tempDir = File.createTempFile("sparksql", "json") + tempDir.delete() + sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + ("a", "b") :: Nil) + + FileUtils.deleteDirectory(tempDir) + sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql("DROP TABLE jsonTable") + + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + // New table should reflect new schema. 
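// i.e. the re-created jsonTable picks up the third column of the rewritten JSON file,
// with no explicit refreshTable call needed, because DROP TABLE invalidated the cache.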
+ checkAnswer( + sql("SELECT * FROM jsonTable"), + ("a", "b", "c") :: Nil) + FileUtils.deleteDirectory(tempDir) + } + + test("invalidate cache and reload") { + sql( + s""" + |CREATE TABLE jsonTable (`c_!@(3)` int) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '${filePath}' + |) + """.stripMargin) + + jsonFile(filePath).registerTempTable("expectedJsonTable") + + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + + // Discard the cached relation. + invalidateTable("jsonTable") + + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + + invalidateTable("jsonTable") + val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) + + assert(expectedSchema == table("jsonTable").schema) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 700a45edb11d..4decd1548534 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -623,7 +623,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertResult( Array( - Array("# Registered as a temporary table", null, null), Array("a", "IntegerType", null), Array("b", "StringType", null)) ) { From 14e3f114efb906937b2d7b7ac04484b2814a3b48 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 13 Jan 2015 13:30:35 -0800 Subject: [PATCH 136/215] [SPARK-5168] Make SQLConf a field rather than mixin in SQLContext This change should be binary and source backward compatible since we didn't change any user facing APIs. Author: Reynold Xin Closes #3965 from rxin/SPARK-5168-sqlconf and squashes the following commits: 42eec09 [Reynold Xin] Fix default conf value. 0ef86cc [Reynold Xin] Fix constructor ordering. 4d7f910 [Reynold Xin] Properly override config. 
ccc8e6a [Reynold Xin] [SPARK-5168] Make SQLConf a field rather than mixin in SQLContext --- .../org/apache/spark/sql/CacheManager.scala | 4 +- .../scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 40 +++++++++++++++---- .../spark/sql/api/java/JavaSQLContext.scala | 6 +-- .../columnar/InMemoryColumnarTableScan.scala | 4 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../spark/sql/execution/ExistingRDD.scala | 4 +- .../spark/sql/execution/SparkPlan.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 18 ++++----- .../apache/spark/sql/execution/commands.scala | 2 +- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../apache/spark/sql/json/JSONRelation.scala | 4 +- .../spark/sql/parquet/ParquetRelation.scala | 7 ++-- .../spark/sql/parquet/ParquetTest.scala | 2 +- .../apache/spark/sql/parquet/newParquet.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/test/TestSQLContext.scala | 5 ++- .../org/apache/spark/sql/JoinSuite.scala | 2 +- .../org/apache/spark/sql/SQLConfSuite.scala | 10 ++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 ++--- .../columnar/InMemoryColumnarQuerySuite.scala | 3 +- .../columnar/PartitionBatchPruningSuite.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 4 +- .../org/apache/spark/sql/json/JsonSuite.scala | 2 +- .../spark/sql/parquet/ParquetIOSuite.scala | 4 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 14 +++---- .../execution/HiveCompatibilitySuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 11 ++--- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../org/apache/spark/sql/hive/TestHive.scala | 6 ++- .../sql/hive/api/java/JavaHiveContext.scala | 6 +-- .../spark/sql/hive/StatisticsSuite.scala | 20 +++++----- .../sql/hive/execution/HiveQuerySuite.scala | 4 +- 33 files changed, 124 insertions(+), 92 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3c9439b2e9a5..e715d9434a2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -91,8 +91,8 @@ private[sql] trait CacheManager { CachedData( planToCache, InMemoryRelation( - useCompression, - columnBatchSize, + conf.useCompression, + conf.columnBatchSize, storageLevel, query.queryExecution.executedPlan, tableName)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f5bf935522da..206d16f5b34f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -61,7 +61,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] trait SQLConf { +private[sql] class SQLConf { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e7021cc3366c..d8efce0cb43e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import java.util.Properties + +import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -49,7 +52,6 @@ import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging - with SQLConf with CacheManager with ExpressionConversions with UDFRegistration @@ -57,6 +59,30 @@ class SQLContext(@transient val sparkContext: SparkContext) self => + // Note that this is a lazy val so we can override the default value in subclasses. + private[sql] lazy val conf: SQLConf = new SQLConf + + /** Set Spark SQL configuration properties. */ + def setConf(props: Properties): Unit = conf.setConf(props) + + /** Set the given Spark SQL configuration property. */ + def setConf(key: String, value: String): Unit = conf.setConf(key, value) + + /** Return the value of Spark SQL configuration property for the given key. */ + def getConf(key: String): String = conf.getConf(key) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. + */ + def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue) + + /** + * Return all the configuration properties that have been set (i.e. not the default). + * This creates a new copy of the config properties in the form of a Map. 
+ */ + def getAllConfs: immutable.Map[String, String] = conf.getAllConfs + @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) @@ -212,7 +238,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { - val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord + val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = Option(schema).getOrElse( JsonRDD.nullTypeToStringType( @@ -226,7 +252,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord + val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) @@ -299,10 +325,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def sql(sqlText: String): SchemaRDD = { - if (dialect == "sql") { + if (conf.dialect == "sql") { new SchemaRDD(this, parseSql(sqlText)) } else { - sys.error(s"Unsupported SQL dialect: $dialect") + sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } } @@ -323,9 +349,9 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self - def codegenEnabled = self.codegenEnabled + def codegenEnabled = self.conf.codegenEnabled - def numPartitions = self.numShufflePartitions + def numPartitions = self.conf.numShufflePartitions def strategies: Seq[Strategy] = extraStrategies ++ ( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 8884204e5079..7f868cd4afca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -52,7 +52,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * @group userf */ def sql(sqlText: String): JavaSchemaRDD = { - if (sqlContext.dialect == "sql") { + if (sqlContext.conf.dialect == "sql") { new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: $sqlContext.dialect") @@ -164,7 +164,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * It goes through the entire dataset once to determine the schema. 
*/ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { - val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord + val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord val appliedScalaSchema = JsonRDD.nullTypeToStringType( JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord)) @@ -182,7 +182,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { */ @Experimental def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { - val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord + val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( JsonRDD.nullTypeToStringType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 1e432485c4c2..065fae3c83df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -82,7 +82,7 @@ private[sql] case class InMemoryRelation( if (batchStats.value.isEmpty) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. - Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) } else { // Underlying columnar RDD has been materialized, required information has also been collected // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`. @@ -233,7 +233,7 @@ private[sql] case class InMemoryColumnarTableScan( val readPartitions = sparkContext.accumulator(0) val readBatches = sparkContext.accumulator(0) - private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning + private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning override def execute() = { readPartitions.setValue(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d7c811ca8902..7c0b72aab448 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -123,7 +123,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una */ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - def numPartitions = sqlContext.numShufflePartitions + def numPartitions = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d2d8cb1c62d4..069e95019530 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -69,7 +69,7 @@ case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLCont @transient override lazy val statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. 
See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) ) } @@ -106,6 +106,6 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ @transient override lazy val statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + sizeInBytes = BigInt(sqlContext.conf.defaultSizeInBytes) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 017c78d2c66d..6fecd1ff066c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -51,7 +51,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ // sqlContext will be null when we are being deserialized on the slaves. In this instance // the value of codegenEnabled will be set by the desserializer after the constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { - sqlContext.codegenEnabled + sqlContext.conf.codegenEnabled } else { false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d91b1fbc6983..0652d2ff7c9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -35,8 +35,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if sqlContext.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => val semiJoin = joins.BroadcastLeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil @@ -81,13 +81,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => + if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => @@ -215,7 +215,7 @@ private[sql] 
abstract class SparkStrategies extends QueryPlanner[SparkPlan] { InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = - if (sqlContext.parquetFilterPushDown) { + if (sqlContext.conf.parquetFilterPushDown) { (predicates: Seq[Expression]) => { // Note: filters cannot be pushed down to Parquet if they contain more complex // expressions than simple "Attribute cmp Literal" comparisons. Here we remove all @@ -237,7 +237,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ParquetTableScan( _, relation, - if (sqlContext.parquetFilterPushDown) filters else Nil)) :: Nil + if (sqlContext.conf.parquetFilterPushDown) filters else Nil)) :: Nil case _ => Nil } @@ -270,7 +270,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. execution.Sort(sortExprs, global = false, planLater(child)) :: Nil - case logical.Sort(sortExprs, global, child) if sqlContext.externalSortEnabled => + case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled => execution.ExternalSort(sortExprs, global, planLater(child)):: Nil case logical.Sort(sortExprs, global, child) => execution.Sort(sortExprs, global, planLater(child)):: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index df8e61615104..af6b07bd6c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -94,7 +94,7 @@ case class SetCommand( logWarning( s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.numShufflePartitions}")) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.conf.numShufflePartitions}")) // Queries a single property. 
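// e.g. "SET spark.sql.codegen" (a key with no value) falls through to the branch
// below and returns that one setting as a single "key=value" row: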
case Some((key, None)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index fbe1d06ed2e8..2dd22c020ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -43,7 +43,7 @@ case class BroadcastHashJoin( extends BinaryNode with HashJoin { val timeout = { - val timeoutValue = sqlContext.broadcastTimeout + val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { Duration.Inf } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index a9a6696cb15e..f5c02224c82a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -59,8 +59,8 @@ private[sql] case class JSONRelation( JsonRDD.inferSchema( baseRDD, samplingRatio, - sqlContext.columnNameOfCorruptRecord))) + sqlContext.conf.columnNameOfCorruptRecord))) override def buildScan() = - JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) + JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 2835dc3408b9..cde5160149e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -65,7 +65,7 @@ private[sql] case class ParquetRelation( ParquetTypesConverter.readSchemaFromFile( new Path(path.split(",").head), conf, - sqlContext.isParquetBinaryAsString) + sqlContext.conf.isParquetBinaryAsString) lazy val attributeMap = AttributeMap(output.map(o => o -> o)) @@ -80,7 +80,7 @@ private[sql] case class ParquetRelation( } // TODO: Use data from the footers. 
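// (Each Parquet footer records per-row-group byte sizes, which could eventually
// replace the defaultSizeInBytes fallback below with an exact figure.)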
- override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes) + override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) } private[sql] object ParquetRelation { @@ -163,7 +163,8 @@ private[sql] object ParquetRelation { sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse( - sqlContext.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) + sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) + .name()) ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) new ParquetRelation(path.toString, Some(conf), sqlContext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index b4d48902fd2c..02ce1b3e6d81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -54,7 +54,7 @@ trait ParquetTest { try f finally { keys.zip(currentValues).foreach { case (key, Some(value)) => setConf(key, value) - case (key, None) => unsetConf(key) + case (key, None) => conf.unsetConf(key) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 2e0c6c51c00e..55a2728a85cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -137,7 +137,7 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) ParquetTypesConverter.readSchemaFromFile( partitions.head.files.head.getPath, Some(sparkContext.hadoopConfiguration), - sqlContext.isParquetBinaryAsString)) + sqlContext.conf.isParquetBinaryAsString)) val dataIncludesKey = partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) @@ -198,7 +198,7 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) predicates .reduceOption(And) .flatMap(ParquetFilters.createFilter) - .filter(_ => sqlContext.parquetFilterPushDown) + .filter(_ => sqlContext.conf.parquetFilterPushDown) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 2a7be23e37c7..7f5564baa00f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -99,7 +99,7 @@ abstract class BaseRelation { * large to broadcast. This method will be called multiple times during query planning * and thus should not perform expensive operations for each invocation. 
*/ - def sizeInBytes = sqlContext.defaultSizeInBytes + def sizeInBytes = sqlContext.conf.defaultSizeInBytes } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 6bb81c76ed8b..8c80be106f3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -29,6 +29,7 @@ object TestSQLContext new SparkConf().set("spark.sql.testkey", "true"))) { /** Fewer partitions to speed up testing. */ - override private[spark] def numShufflePartitions: Int = - getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + private[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index c7e136388fce..e5ab16f9dd66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -387,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { clearCache() sql("CACHE TABLE testData") - val tmp = autoBroadcastJoinThreshold + val tmp = conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 60701f0e154f..bf73d0c7074a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -37,7 +37,7 @@ class SQLConfSuite extends QueryTest with FunSuiteLike { } test("programmatic ways of basic setting and getting") { - clear() + conf.clear() assert(getAllConfs.size === 0) setConf(testKey, testVal) @@ -51,11 +51,11 @@ class SQLConfSuite extends QueryTest with FunSuiteLike { assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) assert(TestSQLContext.getAllConfs.contains(testKey)) - clear() + conf.clear() } test("parse SQL set commands") { - clear() + conf.clear() sql(s"set $testKey=$testVal") assert(getConf(testKey, testVal + "_") == testVal) assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) @@ -73,11 +73,11 @@ class SQLConfSuite extends QueryTest with FunSuiteLike { sql(s"set $key=") assert(getConf(key, "0") == "") - clear() + conf.clear() } test("deprecated property") { - clear() + conf.clear() sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d9de5686dce4..bc72daf0880a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -79,7 +79,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("aggregation with codegen") { - val originalValue = codegenEnabled + val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") sql("SELECT key FROM testData GROUP BY key").collect() setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) @@ -245,14 +245,14 @@ class SQLQuerySuite extends QueryTest with 
BeforeAndAfterAll { } test("sorting") { - val before = externalSortEnabled + val before = conf.externalSortEnabled setConf(SQLConf.EXTERNAL_SORT, "false") sortTest() setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("external sorting") { - val before = externalSortEnabled + val before = conf.externalSortEnabled setConf(SQLConf.EXTERNAL_SORT, "true") sortTest() setConf(SQLConf.EXTERNAL_SORT, before.toString) @@ -600,7 +600,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SET commands semantics using sql()") { - clear() + conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -632,7 +632,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql(s"SET $nonexistentKey"), Seq(Seq(s"$nonexistentKey=")) ) - clear() + conf.clear() } test("apply schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index fc95dccc74e2..d94729ba9236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -39,7 +39,8 @@ class InMemoryColumnarQuerySuite extends QueryTest { sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst") cacheTable("sizeTst") assert( - table("sizeTst").queryExecution.logical.statistics.sizeInBytes > autoBroadcastJoinThreshold) + table("sizeTst").queryExecution.logical.statistics.sizeInBytes > + conf.autoBroadcastJoinThreshold) } test("projection") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 1915c25392f1..592cafbbdc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { - val originalColumnBatchSize = columnBatchSize - val originalInMemoryPartitionPruning = inMemoryPartitionPruning + val originalColumnBatchSize = conf.columnBatchSize + val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a5af71acfc79..c5b6fce5fd29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -60,7 +60,7 @@ class PlannerSuite extends FunSuite { } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - val origThreshold = autoBroadcastJoinThreshold + val origThreshold = conf.autoBroadcastJoinThreshold setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) // Using a threshold that is definitely larger than the small testing table (b) below @@ -78,7 +78,7 @@ class PlannerSuite extends FunSuite { } test("InMemoryRelation statistics propagation") { - val origThreshold = 
autoBroadcastJoinThreshold + val origThreshold = conf.autoBroadcastJoinThreshold setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) testData.limit(3).registerTempTable("tiny") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 8dce3372a8db..b09f1ac49553 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -713,7 +713,7 @@ class JsonSuite extends QueryTest { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord + val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val jsonSchemaRDD = jsonRDD(corruptRecords) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 10a01474e95b..6ac67fcafe16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -213,7 +213,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { withParquetFile(data) { path => - assertResult(parquetCompressionCodec.toUpperCase) { + assertResult(conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) } } @@ -221,7 +221,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(parquetCompressionCodec)) + checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index a5fe2e8da284..0a92336a3cb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -88,7 +88,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA TestData // Load test data tables. private var testRDD: SchemaRDD = null - private val originalParquetFilterPushdownEnabled = TestSQLContext.parquetFilterPushDown + private val originalParquetFilterPushdownEnabled = TestSQLContext.conf.parquetFilterPushDown override def beforeAll() { ParquetTestData.writeFile() @@ -144,7 +144,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } ignore("Treat binary as string") { - val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString + val oldIsParquetBinaryAsString = TestSQLContext.conf.isParquetBinaryAsString // Create the test file. 
val file = getTempFilePath("parquet") @@ -174,7 +174,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Compression options for writing to a Parquetfile") { - val defaultParquetCompressionCodec = TestSQLContext.parquetCompressionCodec + val defaultParquetCompressionCodec = TestSQLContext.conf.parquetCompressionCodec import scala.collection.JavaConversions._ val file = getTempFilePath("parquet") @@ -186,7 +186,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) var actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -202,7 +202,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -234,7 +234,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) parquetFile(path).registerTempTable("tmp") checkAnswer( @@ -250,7 +250,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA rdd.saveAsParquetFile(path) actualCodec = ParquetTypesConverter.readMetaData(new Path(path), Some(TestSQLContext.sparkContext.hadoopConfiguration)) .getBlocks.flatMap(block => block.getColumns).map(column => column.getCodec.name()).distinct - assert(actualCodec === TestSQLContext.parquetCompressionCodec.toUpperCase :: Nil) + assert(actualCodec === TestSQLContext.conf.parquetCompressionCodec.toUpperCase :: Nil) parquetFile(path).registerTempTable("tmp") checkAnswer( diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 23283fd3fe6b..0d934620aca0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -36,8 +36,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - private val originalColumnBatchSize = TestHive.columnBatchSize - private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning + 
private val originalColumnBatchSize = TestHive.conf.columnBatchSize + private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 09ff4cc5ab43..9aeebd7e5436 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -71,8 +71,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => - // Change the default SQL dialect to HiveQL - override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + private[sql] override lazy val conf: SQLConf = new SQLConf { + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + } /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe @@ -87,12 +88,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def sql(sqlText: String): SchemaRDD = { // TODO: Create a framework for registering parsers instead of just hardcoding if statements. - if (dialect == "sql") { + if (conf.dialect == "sql") { super.sql(sqlText) - } else if (dialect == "hiveql") { + } else if (conf.dialect == "hiveql") { new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) } else { - sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") + sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index daeabb6c8bab..785a6a14f49f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -515,7 +515,7 @@ private[hive] case class MetastoreRelation // if the size is still less than zero, we use default size Option(totalSize).map(_.toLong).filter(_ > 0) .getOrElse(Option(rawDataSize).map(_.toLong).filter(_ > 0) - .getOrElse(sqlContext.defaultSizeInBytes))) + .getOrElse(sqlContext.conf.defaultSizeInBytes))) } ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 31c7ce96398e..52e1f0d94fbd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -102,8 +102,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { new this.QueryExecution { val logical = plan } /** Fewer partitions to speed up testing. 
*/ - override private[spark] def numShufflePartitions: Int = - getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + private[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + } /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala index 1817c7832490..038f63f6c744 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -31,12 +31,12 @@ class JavaHiveContext(sqlContext: SQLContext) extends JavaSQLContext(sqlContext) override def sql(sqlText: String): JavaSchemaRDD = { // TODO: Create a framework for registering parsers instead of just hardcoding if statements. - if (sqlContext.dialect == "sql") { + if (sqlContext.conf.dialect == "sql") { super.sql(sqlText) - } else if (sqlContext.dialect == "hiveql") { + } else if (sqlContext.conf.dialect == "hiveql") { new JavaSchemaRDD(sqlContext, HiveQl.parseSql(sqlText)) } else { - sys.error(s"Unsupported SQL dialect: ${sqlContext.dialect}. Try 'sql' or 'hiveql'") + sys.error(s"Unsupported SQL dialect: ${sqlContext.conf.dialect}. Try 'sql' or 'hiveql'") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a758f921e041..0b4e76c9d3d2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -81,7 +81,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // TODO: How does it works? needs to add it back for other hive version. 
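// (Until ANALYZE runs, the freshly created table has no totalSize/rawDataSize
// statistics, so its estimated size falls back to defaultSizeInBytes; the assertion
// is only made for Hive 0.12 here.)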
if (HiveShim.version =="0.12.0") { - assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) + assert(queryTotalSize("analyzeTable") === conf.defaultSizeInBytes) } sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") @@ -110,7 +110,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -151,8 +151,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = rdd.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold - && sizes(1) <= autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold + && sizes(1) <= conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -163,8 +163,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(rdd, expectedAnswer) // check correctness of output - TestHive.settings.synchronized { - val tmp = autoBroadcastJoinThreshold + TestHive.conf.settings.synchronized { + val tmp = conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") rdd = sql(query) @@ -207,8 +207,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold - && sizes(0) <= autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold + && sizes(0) <= conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -221,8 +221,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(rdd, answer) // check correctness of output - TestHive.settings.synchronized { - val tmp = autoBroadcastJoinThreshold + TestHive.conf.settings.synchronized { + val tmp = conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") rdd = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 4decd1548534..c14f0d24e0dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -847,7 +847,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet - clear() + conf.clear() // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... 
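// (sql("SET") with no arguments yields one "key=value" Row per explicitly set
// property; unset defaults are not listed, hence the TODO above.)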
@@ -879,7 +879,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { collectResults(sql(s"SET $nonexistentKey")) } - clear() + conf.clear() } createQueryTest("select from thrift based table", From f9969098c8cb15e36c718b80c6cf5b534a6cf7c3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 13 Jan 2015 17:16:41 -0800 Subject: [PATCH 137/215] [SPARK-5123][SQL] Reconcile Java/Scala API for data types. Having two versions of the data type APIs (one for Java, one for Scala) requires downstream libraries to also have two versions of the APIs if the library wants to support both Java and Scala. I took a look at the Scala version of the data type APIs - it can actually work out pretty well for Java out of the box. As part of the PR, I created a sql.types package and moved all type definitions there. I then removed the Java specific data type API along with a lot of the conversion code. This subsumes https://github.com/apache/spark/pull/3925 Author: Reynold Xin Closes #3958 from rxin/SPARK-5123-datatype-2 and squashes the following commits: 66505cc [Reynold Xin] [SPARK-5123] Expose only one version of the data type APIs (i.e. remove the Java-specific API). --- .../scala/org/apache/spark/ml/Pipeline.scala | 5 +- .../org/apache/spark/ml/Transformer.scala | 2 +- .../classification/LogisticRegression.scala | 1 + .../BinaryClassificationEvaluator.scala | 3 +- .../apache/spark/ml/feature/HashingTF.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 1 + .../apache/spark/ml/feature/Tokenizer.scala | 2 +- .../spark/ml/tuning/CrossValidator.scala | 3 +- .../apache/spark/mllib/linalg/Vectors.scala | 3 +- project/MimaExcludes.scala | 12 + project/SparkBuild.scala | 4 +- sql/README.md | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 5 +- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 4 +- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 3 +- .../spark/sql/catalyst/expressions/Rand.scala | 2 +- .../spark/sql/catalyst/expressions/Row.scala | 2 +- .../sql/catalyst/expressions/ScalaUdf.scala | 3 +- .../expressions/SpecificMutableRow.scala | 2 +- .../catalyst/expressions/WrapDynamic.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 10 +- .../codegen/GenerateOrdering.scala | 2 +- .../codegen/GenerateProjection.scala | 2 +- .../catalyst/expressions/complexTypes.scala | 2 +- .../expressions/decimalFunctions.scala | 4 +- .../sql/catalyst/expressions/generators.scala | 2 +- .../sql/catalyst/expressions/literals.scala | 4 +- .../expressions/namedExpressions.scala | 3 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../spark/sql/catalyst/expressions/sets.scala | 2 +- .../expressions/stringOperations.scala | 2 +- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../apache/spark/sql/catalyst/package.scala | 2 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 17 +- .../plans/logical/basicOperators.scala | 2 +- .../plans/physical/partitioning.scala | 2 +- .../spark/sql/types/DataTypeConversions.scala | 68 ++++ .../apache/spark/sql/types/DataTypes.java} | 50 +-- .../{catalyst/util => types}/Metadata.scala | 16 +- .../SQLUserDefinedType.java | 3 +- .../sql/{catalyst => 
}/types/dataTypes.scala | 294 +++++++++++++-- .../types/decimal/Decimal.scala | 2 +- .../sql/{catalyst => }/types/package.scala | 3 +- .../sql/catalyst/ScalaReflectionSuite.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 2 +- .../analysis/HiveTypeCoercionSuite.scala | 2 +- .../ExpressionEvaluationSuite.scala | 10 +- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 2 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 2 +- .../sql/catalyst/util/MetadataSuite.scala | 2 + .../spark/sql/types}/DataTypeSuite.scala | 2 +- .../types/decimal/DecimalSuite.scala | 2 +- .../apache/spark/sql/api/java/ArrayType.java | 68 ---- .../apache/spark/sql/api/java/BinaryType.java | 27 -- .../spark/sql/api/java/BooleanType.java | 27 -- .../apache/spark/sql/api/java/ByteType.java | 27 -- .../apache/spark/sql/api/java/DateType.java | 27 -- .../spark/sql/api/java/DecimalType.java | 79 ---- .../apache/spark/sql/api/java/DoubleType.java | 27 -- .../apache/spark/sql/api/java/FloatType.java | 27 -- .../spark/sql/api/java/IntegerType.java | 27 -- .../apache/spark/sql/api/java/LongType.java | 27 -- .../apache/spark/sql/api/java/MapType.java | 78 ---- .../apache/spark/sql/api/java/Metadata.java | 31 -- .../spark/sql/api/java/MetadataBuilder.java | 28 -- .../apache/spark/sql/api/java/NullType.java | 27 -- .../apache/spark/sql/api/java/ShortType.java | 27 -- .../apache/spark/sql/api/java/StringType.java | 27 -- .../spark/sql/api/java/StructField.java | 91 ----- .../apache/spark/sql/api/java/StructType.java | 58 --- .../spark/sql/api/java/TimestampType.java | 27 -- .../spark/sql/api/java/UserDefinedType.java | 54 --- .../org/apache/spark/sql/SQLContext.scala | 9 +- .../org/apache/spark/sql/SchemaRDD.scala | 6 +- .../org/apache/spark/sql/SparkSQLParser.scala | 4 +- .../spark/sql/api/java/JavaSQLContext.scala | 69 ++-- .../spark/sql/api/java/JavaSchemaRDD.scala | 10 +- .../org/apache/spark/sql/api/java/Row.scala | 2 - .../spark/sql/api/java/UDFRegistration.scala | 139 ++++--- .../spark/sql/api/java/UDTWrappers.scala | 75 ---- .../spark/sql/columnar/ColumnAccessor.scala | 4 +- .../spark/sql/columnar/ColumnBuilder.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 2 +- .../spark/sql/columnar/ColumnType.scala | 2 +- .../CompressibleColumnAccessor.scala | 2 +- .../CompressibleColumnBuilder.scala | 2 +- .../compression/CompressionScheme.scala | 2 +- .../compression/compressionSchemes.scala | 3 +- .../spark/sql/execution/ExistingRDD.scala | 3 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../sql/execution/SparkSqlSerializer.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../spark/sql/execution/debug/package.scala | 2 +- .../spark/sql/execution/pythonUdfs.scala | 4 +- .../apache/spark/sql/json/JSONRelation.scala | 3 +- .../org/apache/spark/sql/json/JsonRDD.scala | 8 +- .../scala/org/apache/spark/sql/package.scala | 349 ------------------ .../spark/sql/parquet/ParquetConverter.scala | 17 +- .../spark/sql/parquet/ParquetFilters.scala | 2 +- .../sql/parquet/ParquetTableSupport.scala | 4 +- .../spark/sql/parquet/ParquetTypes.scala | 16 +- .../apache/spark/sql/parquet/newParquet.scala | 7 +- .../spark/sql/sources/LogicalRelation.scala | 4 +- .../org/apache/spark/sql/sources/ddl.scala | 10 +- .../apache/spark/sql/sources/interfaces.scala | 3 +- .../spark/sql/test/ExamplePointUDT.scala | 4 +- .../sql/types/util/DataTypeConversions.scala | 175 --------- .../spark/sql/api/java/JavaAPISuite.java | 9 +- 
.../sql/api/java/JavaApplySchemaSuite.java | 23 +- .../java/JavaSideDataTypeConversionSuite.java | 150 -------- .../org/apache/spark/sql/DslQuerySuite.scala | 1 + .../scala/org/apache/spark/sql/RowSuite.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 3 +- .../sql/ScalaReflectionRelationSuite.scala | 1 - .../spark/sql/UserDefinedTypeSuite.scala | 3 +- .../spark/sql/api/java/JavaSQLSuite.scala | 9 +- .../ScalaSideDataTypeConversionSuite.scala | 89 ----- .../spark/sql/columnar/ColumnStatsSuite.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 2 +- .../sql/columnar/ColumnarTestUtils.scala | 2 +- .../NullableColumnAccessorSuite.scala | 2 +- .../columnar/NullableColumnBuilderSuite.scala | 2 +- .../compression/DictionaryEncodingSuite.scala | 2 +- .../compression/IntegralDeltaSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 2 +- .../TestCompressibleColumnBuilder.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 16 +- .../spark/sql/parquet/ParquetIOSuite.scala | 6 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 16 +- .../sql/parquet/ParquetSchemaSuite.scala | 21 +- .../spark/sql/sources/FilteredScanSuite.scala | 2 + .../spark/sql/sources/PrunedScanSuite.scala | 1 + .../spark/sql/sources/TableScanSuite.scala | 2 +- .../spark/sql/hive/thriftserver/Shim12.scala | 4 +- .../spark/sql/hive/thriftserver/Shim13.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveInspectors.scala | 15 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 5 +- .../spark/sql/hive/HiveStrategies.scala | 7 +- .../hive/execution/HiveNativeCommand.scala | 2 +- .../sql/hive/execution/HiveTableScan.scala | 4 +- .../spark/sql/hive/execution/commands.scala | 4 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 2 +- .../spark/sql/hive/HiveInspectorSuite.scala | 19 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 2 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 1 + .../sql/hive/MetastoreDataSourcesSuite.scala | 1 + .../org/apache/spark/sql/hive/Shim12.scala | 20 +- .../org/apache/spark/sql/hive/Shim13.scala | 11 +- 160 files changed, 756 insertions(+), 2096 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala rename sql/{core/src/main/java/org/apache/spark/sql/api/java/DataType.java => catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java} (81%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst/util => types}/Metadata.scala (96%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst/annotation => types}/SQLUserDefinedType.java (93%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst => }/types/dataTypes.scala (76%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst => }/types/decimal/Decimal.scala (99%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/{catalyst => }/types/package.scala (96%) rename sql/{core/src/test/scala/org/apache/spark/sql => catalyst/src/test/scala/org/apache/spark/sql/types}/DataTypeSuite.scala (98%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/{catalyst => }/types/decimal/DecimalSuite.scala (99%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java delete mode 100644 
sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 081a574beea5..ad6fed178fae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -21,8 +21,9 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, Param, ParamMap} -import org.apache.spark.sql.{SchemaRDD, StructType} +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.types.StructType /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 23fbd228d01c..1331b9124045 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.SchemaRDD import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions.ScalaUdf -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 85b8899636ca..8c570812f831 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 0b0504e036ec..12473cb2b571 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{DoubleType, Row, SchemaRDD} +import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index e0bfb1e484a2..0956062643f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vector} -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.types.DataType /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 896a6b83b67b..72825f6e0218 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. 
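For code outside Catalyst, the practical effect of SPARK-5123 is that Scala and Java now share one set of data type classes under org.apache.spark.sql.types; the deleted org.apache.spark.sql.api.java type classes are replaced by the DataTypes factory (DataTypes.StringType, DataTypes.createStructField, and so on). The following is only a rough sketch of schema construction against the relocated package, with illustrative field names that are not taken from the patch:

    import org.apache.spark.sql.types._

    // One set of type classes for both languages: Scala builds StructType/StructField
    // directly, while Java reaches the same singletons through the DataTypes factory.
    // Field names below are hypothetical.
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = false),
      // Metadata and MetadataBuilder also live in sql.types now (moved out of catalyst.util).
      StructField("score", DoubleType, nullable = true,
        metadata = new MetadataBuilder().putString("source", "example").build())))
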
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 9352f40f372d..e622a5cf9e6f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.{DataType, StringType, ArrayType} +import org.apache.spark.sql.types.{DataType, StringType, ArrayType} /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 194b9bfd9a9e..08fe99176424 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{SchemaRDD, StructType} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index bf1faa25ef0e..adbd8266ed6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -27,9 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 51e8bd4cf641..f6f9f491f4ce 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -60,6 +60,18 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ "removeAndGetProcessor") + ) ++ Seq( + // SPARK-5123 (SparkSQL data type change) - alpha component only + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.HashingTF.outputDataType"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.outputDataType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.validateInputType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 46a54c681840..b2c546da21c7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -254,10 +254,10 @@ object SQL { |import org.apache.spark.sql.catalyst.expressions._ |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ - |import org.apache.spark.sql.catalyst.types._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.test.TestSQLContext._ + |import org.apache.spark.sql.types._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, cleanupCommands in console := "sparkContext.stop()" ) @@ -284,11 +284,11 @@ object Hive { |import org.apache.spark.sql.catalyst.expressions._ |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ - |import org.apache.spark.sql.catalyst.types._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ + |import org.apache.spark.sql.types._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce diff --git a/sql/README.md b/sql/README.md index 8d2f3cf4283e..d058a6b011d3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -34,11 +34,11 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.TestHive._ +import org.apache.spark.sql.types._ Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). Type in expressions to have them evaluated. Type :help for more information. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 2cf241de61f7..d169da691d79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 5d974df98b69..d19563e95cc7 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * A very simple SQL parser. Based loosely on: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c009cc1e1e85..bd00ff22ba80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.types.StructType -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.IntegerType /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 242f28f67029..15353361d97c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index b2262e5e6efb..bdac7504ed02 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /** * A collection of implicit conversions that create a DSL for constructing catalyst data structures. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index fa80b07f8e6b..76a9f08dea85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4ede0b4821fe..00961f09916b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,8 +22,8 @@ import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /** Cast the child expression to the target data type. 
*/ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ac5b02c2e6ae..cf14992ef835 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} -import org.apache.spark.sql.catalyst.util.Metadata +import org.apache.spark.sql.types._ abstract class Expression extends TreeNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index 851db95b9177..b2c6d3029031 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Random -import org.apache.spark.sql.catalyst.types.DoubleType +import org.apache.spark.sql.types.DoubleType case object Rand extends LeafExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 463f3667fc44..dcda53bb717a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.types.NativeType object Row { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 18c96da2f87f..8a36c6810790 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.types.DataType -import org.apache.spark.util.ClosureCleaner +import org.apache.spark.sql.types.DataType /** * User-defined function. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 570379c533e1..37d9f0ed5c79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * A parent class for mutable container objects that are reused when the values are changed, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 1a4ac06c7a79..8328278544a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.language.dynamics -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.types.DataType /** * The data type representing [[DynamicRow]] values. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5ea9868e9e84..735b7488fdcb 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.util.collection.OpenHashSet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 168a963e29c9..574907f566c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 90c81b2631e5..a5d642339129 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.catalyst.expressions.codegen import com.google.common.cache.{CacheLoader, CacheBuilder} -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import 
org.apache.spark.sql.types.decimal.Decimal import scala.language.existentials import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ // These classes are here to avoid issues with serialization and integration with quasiquotes. class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] @@ -541,11 +541,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin childEval.code ++ q""" var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal = + var $primitiveTerm: org.apache.spark.sql.types.decimal.Decimal = ${defaultPrimitive(DecimalType())} if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal() + $primitiveTerm = new org.apache.spark.sql.types.decimal.Decimal() $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) $nullTerm = $primitiveTerm == null } @@ -627,7 +627,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case LongType => ru.Literal(Constant(1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)" + case DecimalType() => q"org.apache.spark.sql.types.decimal.Decimal(-1)" case IntegerType => ru.Literal(Constant(-1)) case _ => ru.Literal(Constant(null)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 094ff1455228..0db29eb404bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{StringType, NumericType} +import org.apache.spark.sql.types.{StringType, NumericType} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2ff61169a17d..cc97cb4f50b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 9aec601886ef..1bc34f71441f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -19,7 
+19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index d1eab2eb4ed5..e54cfa144a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType} +import org.apache.spark.sql.types.decimal.Decimal +import org.apache.spark.sql.types.{DecimalType, LongType, DoubleType, DataType} /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ case class UnscaledValue(child: Expression) extends UnaryExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index ab0701fd9a80..43b6482c0171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * An expression that produces zero or more rows given a single input row. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 94e1d37c1c3a..8ee4bbd8caa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal object Literal { def apply(v: Any): Literal = v match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a3c300b5d90e..3035d934ff9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.util.Metadata +import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index cb5ff6795986..c84cc95520a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.types.BooleanType +import org.apache.spark.sql.types.BooleanType object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 3d4c4a8853c1..3a5bdca1f07c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f6349767764a..f85ee0a9bb6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -23,7 +23,7 @@ import 
scala.collection.IndexedSeqOptimized import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} trait StringRegexExpression { self: BinaryExpression => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index cd3137980ca4..17b4f9c23a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal abstract class Optimizer extends RuleExecutor[LogicalPlan] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index a38079ced34b..105cdf52500c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -27,6 +27,6 @@ package object catalyst { * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for * 2.10.* builds. See SI-6240 for more details. */ - protected[catalyst] object ScalaReflectionLock + protected[sql] object ScalaReflectionLock } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index dcbbb62c0aca..619f42859cbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.{ArrayType, DataType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType with Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ed578e081be7..65ae066e4b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.StructType +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.catalyst.trees /** @@ -191,14 
+191,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { case (Nil, _) => expression case (requestedField :: rest, StructType(fields)) => val actualField = fields.filter(f => resolver(f.name, requestedField)) - actualField match { - case Seq() => - sys.error( - s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") - case Seq(singleMatch) => - resolveNesting(rest, GetField(expression, singleMatch.name), resolver) - case multipleMatches => - sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}") + if (actualField.length == 0) { + sys.error( + s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") + } else if (actualField.length == 1) { + resolveNesting(rest, GetField(expression, actualField(0).name), resolver) + } else { + sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}") } case (_, dt) => sys.error(s"Can't access nested field in type $dt") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0b9f01cbae9e..1483beacc908 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { def output = projectList.map(_.toAttribute) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ccb0df113c06..3c3d7a311906 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.types.IntegerType /** * Specifies how tuples that share common expressions will be distributed when a query is executed diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala new file mode 100644 index 000000000000..2a8914cde248 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import java.text.SimpleDateFormat + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.decimal.Decimal + + +protected[sql] object DataTypeConversions { + + def stringToTime(s: String): java.util.Date = { + if (!s.contains('T')) { + // JDBC escape string + if (s.contains(' ')) { + java.sql.Timestamp.valueOf(s) + } else { + java.sql.Date.valueOf(s) + } + } else if (s.endsWith("Z")) { + // this is zero timezone of ISO8601 + stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") + } else if (s.indexOf("GMT") == -1) { + // timezone with ISO8601 + val inset = "+00.00".length + val s0 = s.substring(0, s.length - inset) + val s1 = s.substring(s.length - inset, s.length) + if (s0.substring(s0.lastIndexOf(':')).contains('.')) { + stringToTime(s0 + "GMT" + s1) + } else { + stringToTime(s0 + ".0GMT" + s1) + } + } else { + // ISO8601 with GMT insert + val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) + ISO8601GMT.parse(s) + } + } + + /** Converts Java objects to catalyst rows / types */ + def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type + case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) + case (other, _) => other + } + + /** Converts Java objects to catalyst rows / types */ + def convertCatalystToJava(a: Any): Any = a match { + case d: scala.math.BigDecimal => d.underlying() + case other => other + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java similarity index 81% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java index c69bbd5736a5..e457542c647e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypes.java @@ -15,77 +15,74 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package org.apache.spark.sql.types; import java.util.*; /** - * The base type of all Spark SQL data types. - * * To get/create specific data type, users should use singleton objects and factory methods * provided by this class. */ -public abstract class DataType { - +public class DataTypes { /** * Gets the StringType object. */ - public static final StringType StringType = new StringType(); + public static final DataType StringType = StringType$.MODULE$; /** * Gets the BinaryType object. */ - public static final BinaryType BinaryType = new BinaryType(); + public static final DataType BinaryType = BinaryType$.MODULE$; /** * Gets the BooleanType object. */ - public static final BooleanType BooleanType = new BooleanType(); + public static final DataType BooleanType = BooleanType$.MODULE$; /** * Gets the DateType object. 
*/ - public static final DateType DateType = new DateType(); + public static final DataType DateType = DateType$.MODULE$; /** * Gets the TimestampType object. */ - public static final TimestampType TimestampType = new TimestampType(); + public static final DataType TimestampType = TimestampType$.MODULE$; /** * Gets the DoubleType object. */ - public static final DoubleType DoubleType = new DoubleType(); + public static final DataType DoubleType = DoubleType$.MODULE$; /** * Gets the FloatType object. */ - public static final FloatType FloatType = new FloatType(); + public static final DataType FloatType = FloatType$.MODULE$; /** * Gets the ByteType object. */ - public static final ByteType ByteType = new ByteType(); + public static final DataType ByteType = ByteType$.MODULE$; /** * Gets the IntegerType object. */ - public static final IntegerType IntegerType = new IntegerType(); + public static final DataType IntegerType = IntegerType$.MODULE$; /** * Gets the LongType object. */ - public static final LongType LongType = new LongType(); + public static final DataType LongType = LongType$.MODULE$; /** * Gets the ShortType object. */ - public static final ShortType ShortType = new ShortType(); + public static final DataType ShortType = ShortType$.MODULE$; /** * Gets the NullType object. */ - public static final NullType NullType = new NullType(); + public static final DataType NullType = NullType$.MODULE$; /** * Creates an ArrayType by specifying the data type of elements ({@code elementType}). @@ -95,7 +92,6 @@ public static ArrayType createArrayType(DataType elementType) { if (elementType == null) { throw new IllegalArgumentException("elementType should not be null."); } - return new ArrayType(elementType, true); } @@ -107,10 +103,17 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu if (elementType == null) { throw new IllegalArgumentException("elementType should not be null."); } - return new ArrayType(elementType, containsNull); } + public static DecimalType createDecimalType(int precision, int scale) { + return DecimalType$.MODULE$.apply(precision, scale); + } + + public static DecimalType createDecimalType() { + return DecimalType$.MODULE$.Unlimited(); + } + /** * Creates a MapType by specifying the data type of keys ({@code keyType}) and values * ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}. 
@@ -122,7 +125,6 @@ public static MapType createMapType(DataType keyType, DataType valueType) { if (valueType == null) { throw new IllegalArgumentException("valueType should not be null."); } - return new MapType(keyType, valueType, true); } @@ -141,7 +143,6 @@ public static MapType createMapType( if (valueType == null) { throw new IllegalArgumentException("valueType should not be null."); } - return new MapType(keyType, valueType, valueContainsNull); } @@ -163,7 +164,6 @@ public static StructField createStructField( if (metadata == null) { throw new IllegalArgumentException("metadata should not be null."); } - return new StructField(name, dataType, nullable, metadata); } @@ -191,18 +191,18 @@ public static StructType createStructType(StructField[] fields) { throw new IllegalArgumentException("fields should not be null."); } Set distinctNames = new HashSet(); - for (StructField field: fields) { + for (StructField field : fields) { if (field == null) { throw new IllegalArgumentException( "fields should not contain any null."); } - distinctNames.add(field.getName()); + distinctNames.add(field.name()); } if (distinctNames.size() != fields.length) { throw new IllegalArgumentException("fields should have distinct names."); } - return new StructType(fields); + return StructType$.MODULE$.apply(fields); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 8172733e94dd..e50e9761431f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -15,24 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.util +package org.apache.spark.sql.types import scala.collection.mutable import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.DeveloperApi + + /** + * :: DeveloperApi :: + * * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and * Array[Metadata]. JSON is used for serialization. * * The default constructor is private. User should use either [[MetadataBuilder]] or - * [[Metadata$#fromJson]] to create Metadata instances. + * [[Metadata.fromJson()]] to create Metadata instances. * * @param map an immutable map that stores the data */ -sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable { +@DeveloperApi +sealed class Metadata private[types] (private[types] val map: Map[String, Any]) + extends Serializable { /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) @@ -201,8 +208,11 @@ object Metadata { } /** + * :: DeveloperApi :: + * * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. 
*/ +@DeveloperApi class MetadataBuilder { private val map: mutable.Map[String, Any] = mutable.Map.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java similarity index 93% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java index e966aeea1cb2..a64d2bb7cde3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.annotation; +package org.apache.spark.sql.types; import java.lang.annotation.*; import org.apache.spark.annotation.DeveloperApi; -import org.apache.spark.sql.catalyst.types.UserDefinedType; /** * ::DeveloperApi:: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala similarity index 76% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 892b7e1a97c8..fa0a355ebc9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -15,11 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.types +package org.apache.spark.sql.types import java.sql.{Date, Timestamp} -import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} +import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers @@ -31,11 +31,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} -import org.apache.spark.sql.catalyst.types.decimal._ -import org.apache.spark.sql.catalyst.util.Metadata +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.types.decimal._ import org.apache.spark.util.Utils + object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -140,7 +140,7 @@ object DataType { protected lazy val structType: Parser[DataType] = "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => new StructType(fields) + case fields => StructType(fields) } protected lazy val dataType: Parser[DataType] = @@ -181,7 +181,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. 
*/ - def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) @@ -200,6 +200,15 @@ object DataType { } } + +/** + * :: DeveloperApi :: + * + * The base type of all Spark SQL data types. + * + * @group dataType + */ +@DeveloperApi abstract class DataType { /** Matches any expression that evaluates to this DataType */ def unapply(a: Expression): Boolean = a match { @@ -218,8 +227,18 @@ abstract class DataType { def prettyJson: String = pretty(render(jsonValue)) } + +/** + * :: DeveloperApi :: + * + * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @group dataType + */ +@DeveloperApi case object NullType extends DataType + object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -237,10 +256,12 @@ object NativeType { StringType -> 4096) } + trait PrimitiveType extends DataType { override def isPrimitive = true } + object PrimitiveType { private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap @@ -267,12 +288,31 @@ abstract class NativeType extends DataType { } } + +/** + * :: DeveloperApi :: + * + * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. + * + * @group dataType + */ +@DeveloperApi case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] } + +/** + * :: DeveloperApi :: + * + * The data type representing `Array[Byte]` values. + * Please use the singleton [[DataTypes.BinaryType]]. + * + * @group dataType + */ +@DeveloperApi case object BinaryType extends NativeType with PrimitiveType { private[sql] type JvmType = Array[Byte] @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -287,12 +327,31 @@ case object BinaryType extends NativeType with PrimitiveType { } } + +/** + * :: DeveloperApi :: + * + * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. + * + *@group dataType + */ +@DeveloperApi case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] } + +/** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Timestamp` values. + * Please use the singleton [[DataTypes.TimestampType]]. + * + * @group dataType + */ +@DeveloperApi case object TimestampType extends NativeType { private[sql] type JvmType = Timestamp @@ -303,6 +362,16 @@ case object TimestampType extends NativeType { } } + +/** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Date` values. + * Please use the singleton [[DataTypes.DateType]]. 
+ * + * @group dataType + */ +@DeveloperApi case object DateType extends NativeType { private[sql] type JvmType = Date @@ -313,6 +382,7 @@ case object DateType extends NativeType { } } + abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a @@ -322,10 +392,12 @@ abstract class NumericType extends NativeType with PrimitiveType { private[sql] val numeric: Numeric[JvmType] } + object NumericType { def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } + /** Matcher for any expressions that evaluate to [[IntegralType]]s */ object IntegralType { def unapply(a: Expression): Boolean = a match { @@ -334,10 +406,20 @@ object IntegralType { } } -abstract class IntegralType extends NumericType { + +sealed abstract class IntegralType extends NumericType { private[sql] val integral: Integral[JvmType] } + +/** + * :: DeveloperApi :: + * + * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. + * + * @group dataType + */ +@DeveloperApi case object LongType extends IntegralType { private[sql] type JvmType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -346,6 +428,15 @@ case object LongType extends IntegralType { private[sql] val ordering = implicitly[Ordering[JvmType]] } + +/** + * :: DeveloperApi :: + * + * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. + * + * @group dataType + */ +@DeveloperApi case object IntegerType extends IntegralType { private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -354,6 +445,15 @@ case object IntegerType extends IntegralType { private[sql] val ordering = implicitly[Ordering[JvmType]] } + +/** + * :: DeveloperApi :: + * + * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. + * + * @group dataType + */ +@DeveloperApi case object ShortType extends IntegralType { private[sql] type JvmType = Short @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -362,6 +462,15 @@ case object ShortType extends IntegralType { private[sql] val ordering = implicitly[Ordering[JvmType]] } + +/** + * :: DeveloperApi :: + * + * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. 
+ * + * @group dataType + */ +@DeveloperApi case object ByteType extends IntegralType { private[sql] type JvmType = Byte @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -370,6 +479,7 @@ case object ByteType extends IntegralType { private[sql] val ordering = implicitly[Ordering[JvmType]] } + /** Matcher for any expressions that evaluate to [[FractionalType]]s */ object FractionalType { def unapply(a: Expression): Boolean = a match { @@ -378,15 +488,28 @@ object FractionalType { } } -abstract class FractionalType extends NumericType { + +sealed abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] } + /** Precision parameters for a Decimal */ case class PrecisionInfo(precision: Int, scale: Int) -/** A Decimal that might have fixed precision and scale, or unlimited values for these */ + +/** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * A Decimal that might have fixed precision and scale, or unlimited values for these. + * + * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @group dataType + */ +@DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -395,6 +518,10 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT private[sql] val ordering = Decimal.DecimalIsFractional private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + + def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + override def typeName: String = precisionInfo match { case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" case None => "decimal" @@ -406,6 +533,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT } } + /** Extra factory methods and pattern matchers for Decimals */ object DecimalType { val Unlimited: DecimalType = DecimalType(None) @@ -437,6 +565,15 @@ object DecimalType { } } + +/** + * :: DeveloperApi :: + * + * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. + * + * @group dataType + */ +@DeveloperApi case object DoubleType extends FractionalType { private[sql] type JvmType = Double @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -446,6 +583,15 @@ case object DoubleType extends FractionalType { private[sql] val asIntegral = DoubleAsIfIntegral } + +/** + * :: DeveloperApi :: + * + * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @group dataType + */ +@DeveloperApi case object FloatType extends FractionalType { private[sql] type JvmType = Float @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } @@ -455,18 +601,31 @@ case object FloatType extends FractionalType { private[sql] val asIntegral = FloatAsIfIntegral } + object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) } + /** + * :: DeveloperApi :: + * * The data type for collections of multiple values. 
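
As a quick illustration of the `precision` and `scale` accessors added to DecimalType above (the -1 fallback comes directly from the getOrElse in the new definitions):

{{{
import org.apache.spark.sql.types.{DecimalType, PrecisionInfo}

val fixed = DecimalType(Some(PrecisionInfo(10, 2)))
fixed.precision                   // 10
fixed.scale                       // 2

// The Unlimited variant carries no PrecisionInfo, so both accessors fall back to -1.
DecimalType.Unlimited.precision   // -1
DecimalType.Unlimited.scale       // -1
}}}
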
* Internally these are represented as columns that contain a ``scala.collection.Seq``. * + * Please use [[DataTypes.createArrayType()]] to create a specific instance. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * * @param elementType The data type of values. * @param containsNull Indicates if values have `null` values + * + * @group dataType */ +@DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( @@ -480,8 +639,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT ("containsNull" -> containsNull) } + /** * A field inside a StructType. + * * @param name The name of this field. * @param dataType The data type of this field. * @param nullable Indicates if values of this field can be `null` values. @@ -510,19 +671,92 @@ case class StructField( } } + object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + + def apply(fields: java.util.List[StructField]): StructType = { + StructType(fields.toArray.asInstanceOf[Array[StructField]]) + } } -case class StructType(fields: Seq[StructField]) extends DataType { - /** - * Returns all field names in a [[Seq]]. - */ - lazy val fieldNames: Seq[String] = fields.map(_.name) +/** + * :: DeveloperApi :: + * + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Any names without matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). + * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[org.apache.spark.sql.catalyst.expressions.Row]] object is used as a value of the StructType. 
+ * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ +@DeveloperApi +case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + + /** Returns all field names in an array. */ + def fieldNames: Array[String] = fields.map(_.name) + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. @@ -532,8 +766,8 @@ case class StructType(fields: Seq[StructField]) extends DataType { } /** - * Returns a [[StructType]] containing [[StructField]]s of the given names. - * Those names which do not have matching fields will be ignored. + * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the + * original order of fields. Those names which do not have matching fields will be ignored. */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet @@ -545,8 +779,8 @@ case class StructType(fields: Seq[StructField]) extends DataType { StructType(fields.filter(f => names.contains(f.name))) } - protected[sql] def toAttributes = - fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + protected[sql] def toAttributes: Seq[AttributeReference] = + map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) def treeString: String = { val builder = new StringBuilder @@ -565,23 +799,38 @@ case class StructType(fields: Seq[StructField]) extends DataType { override private[sql] def jsonValue = ("type" -> typeName) ~ - ("fields" -> fields.map(_.jsonValue)) + ("fields" -> map(_.jsonValue)) + + override def apply(fieldIndex: Int): StructField = fields(fieldIndex) + + override def length: Int = fields.length + + override def iterator: Iterator[StructField] = fields.iterator } + object MapType { /** * Construct a [[MapType]] object with the given key type and value type. * The `valueContainsNull` is true. */ def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, true) + MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) } + /** + * :: DeveloperApi :: + * * The data type for Maps. Keys in a map are not allowed to have `null` values. + * + * Please use [[DataTypes.createMapType()]] to create a specific instance. + * * @param keyType The data type of map keys. * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. + * + * @group dataType */ case class MapType( keyType: DataType, @@ -602,6 +851,7 @@ case class MapType( ("valueContainsNull" -> valueContainsNull) } + /** * ::DeveloperApi:: * The data type for User Defined Types (UDTs). @@ -611,7 +861,7 @@ case class MapType( * a SchemaRDD which has class X in the schema. * * For SparkSQL to recognize UDTs, the UDT must be annotated with - * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * [[SQLUserDefinedType]]. 
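
Two user-visible effects of the hunks above, sketched briefly: StructType is now an Array-backed Seq[StructField], so ordinary collection operations and positional access work on it directly, and MapType's two-argument factory defaults valueContainsNull to true.

{{{
import org.apache.spark.sql.types._

val struct = StructType(Seq(
  StructField("a", IntegerType, true),
  StructField("b", LongType, false),
  StructField("c", BooleanType, false)))

struct.fieldNames.toSeq                  // Seq(a, b, c)
struct(1)                                // StructField(b,LongType,false), via the new apply(fieldIndex)
struct.filter(_.nullable).map(_.name)    // Seq(a)

val scores = MapType(StringType, IntegerType)   // valueContainsNull defaults to true
val strict = MapType(StringType, IntegerType, valueContainsNull = false)
}}}
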
* * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala similarity index 99% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala index 708362acf32d..c7864d1ae9e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.types.decimal +package org.apache.spark.sql.types.decimal import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala index de24449590f9..346a51ea10c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala @@ -15,7 +15,8 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst +package org.apache.spark.sql + /** * Contains a type system for attributes produced by relations, including complex types like * structs, arrays and maps. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 7be24bea7d5a..117725df32e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ case class PrimitiveData( intField: Int, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f430057ef719..3aea337460d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index bbbeb4f2e4fe..bc2ec754d586 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfter, FunSuite} class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index dfa2d958c0fa..f5a502b43f80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends FunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 4ba7d87ba8c5..8552448b8d10 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -21,16 +21,14 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.scalactic.TripleEqualsSupport.Spread - -import org.apache.spark.sql.catalyst.types._ - -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal + class ExpressionEvaluationSuite extends FunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 0a27cce33748..9fdf3efa02bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 017b180c574b..da912ab38217 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 036fd3fa1d6a..cdb843f95970 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{StringType, NullType} +import org.apache.spark.sql.types.{StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { def children = optKey.toSeq diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index f005b7df2104..d7d60efee50f 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse import org.scalatest.FunSuite +import org.apache.spark.sql.types.{MetadataBuilder, Metadata} + class MetadataSuite extends FunSuite { val baseMetadata = new MetadataBuilder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index e9740d913cf5..892195f46ea2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.types import org.scalatest.FunSuite diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala similarity index 99% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index e32f1ac38213..813377df0013 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.types.decimal +package org.apache.spark.sql.types.decimal import org.scalatest.{PrivateMethodTester, FunSuite} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java deleted file mode 100644 index b73a371e9300..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/ArrayType.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing Lists. - * An ArrayType object comprises two fields, {@code DataType elementType} and - * {@code boolean containsNull}. The field of {@code elementType} is used to specify the type of - * array elements. The field of {@code containsNull} is used to specify if the array has - * {@code null} values. - * - * To create an {@link ArrayType}, - * {@link DataType#createArrayType(DataType)} or - * {@link DataType#createArrayType(DataType, boolean)} - * should be used. - */ -public class ArrayType extends DataType { - private DataType elementType; - private boolean containsNull; - - protected ArrayType(DataType elementType, boolean containsNull) { - this.elementType = elementType; - this.containsNull = containsNull; - } - - public DataType getElementType() { - return elementType; - } - - public boolean isContainsNull() { - return containsNull; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - ArrayType arrayType = (ArrayType) o; - - if (containsNull != arrayType.containsNull) return false; - if (!elementType.equals(arrayType.elementType)) return false; - - return true; - } - - @Override - public int hashCode() { - int result = elementType.hashCode(); - result = 31 * result + (containsNull ? 1 : 0); - return result; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java deleted file mode 100644 index 7daad60f62a0..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/BinaryType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing byte[] values. - * - * {@code BinaryType} is represented by the singleton object {@link DataType#BinaryType}. - */ -public class BinaryType extends DataType { - protected BinaryType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java deleted file mode 100644 index 5a1f52725631..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing boolean and Boolean values. - * - * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}. - */ -public class BooleanType extends DataType { - protected BooleanType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java deleted file mode 100644 index e5cdf06b21bb..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/ByteType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing byte and Byte values. - * - * {@code ByteType} is represented by the singleton object {@link DataType#ByteType}. 
- */ -public class ByteType extends DataType { - protected ByteType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java deleted file mode 100644 index 6677793baa36..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing java.sql.Date values. - * - * {@code DateType} is represented by the singleton object {@link DataType#DateType}. - */ -public class DateType extends DataType { - protected DateType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java deleted file mode 100644 index 60752451ecfc..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing java.math.BigDecimal values. 
- */ -public class DecimalType extends DataType { - private boolean hasPrecisionInfo; - private int precision; - private int scale; - - public DecimalType(int precision, int scale) { - this.hasPrecisionInfo = true; - this.precision = precision; - this.scale = scale; - } - - public DecimalType() { - this.hasPrecisionInfo = false; - this.precision = -1; - this.scale = -1; - } - - public boolean isUnlimited() { - return !hasPrecisionInfo; - } - - public boolean isFixed() { - return hasPrecisionInfo; - } - - /** Return the precision, or -1 if no precision is set */ - public int getPrecision() { - return precision; - } - - /** Return the scale, or -1 if no precision is set */ - public int getScale() { - return scale; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - DecimalType that = (DecimalType) o; - - if (hasPrecisionInfo != that.hasPrecisionInfo) return false; - if (precision != that.precision) return false; - if (scale != that.scale) return false; - - return true; - } - - @Override - public int hashCode() { - int result = (hasPrecisionInfo ? 1 : 0); - result = 31 * result + precision; - result = 31 * result + scale; - return result; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java deleted file mode 100644 index f0060d0bcf9f..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DoubleType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing double and Double values. - * - * {@code DoubleType} is represented by the singleton object {@link DataType#DoubleType}. - */ -public class DoubleType extends DataType { - protected DoubleType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java deleted file mode 100644 index 4a6a37f69176..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/FloatType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing float and Float values. - * - * {@code FloatType} is represented by the singleton object {@link DataType#FloatType}. - */ -public class FloatType extends DataType { - protected FloatType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java deleted file mode 100644 index bfd70490bbbb..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/IntegerType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing int and Integer values. - * - * {@code IntegerType} is represented by the singleton object {@link DataType#IntegerType}. - */ -public class IntegerType extends DataType { - protected IntegerType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java deleted file mode 100644 index af13a46eb165..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/LongType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing long and Long values. - * - * {@code LongType} is represented by the singleton object {@link DataType#LongType}. 
- */ -public class LongType extends DataType { - protected LongType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java deleted file mode 100644 index 063e6b34abc4..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/MapType.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing Maps. A MapType object comprises two fields, - * {@code DataType keyType}, {@code DataType valueType}, and {@code boolean valueContainsNull}. - * The field of {@code keyType} is used to specify the type of keys in the map. - * The field of {@code valueType} is used to specify the type of values in the map. - * The field of {@code valueContainsNull} is used to specify if map values have - * {@code null} values. - * For values of a MapType column, keys are not allowed to have {@code null} values. - * - * To create a {@link MapType}, - * {@link DataType#createMapType(DataType, DataType)} or - * {@link DataType#createMapType(DataType, DataType, boolean)} - * should be used. - */ -public class MapType extends DataType { - private DataType keyType; - private DataType valueType; - private boolean valueContainsNull; - - protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) { - this.keyType = keyType; - this.valueType = valueType; - this.valueContainsNull = valueContainsNull; - } - - public DataType getKeyType() { - return keyType; - } - - public DataType getValueType() { - return valueType; - } - - public boolean isValueContainsNull() { - return valueContainsNull; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - MapType mapType = (MapType) o; - - if (valueContainsNull != mapType.valueContainsNull) return false; - if (!keyType.equals(mapType.keyType)) return false; - if (!valueType.equals(mapType.valueType)) return false; - - return true; - } - - @Override - public int hashCode() { - int result = keyType.hashCode(); - result = 31 * result + valueType.hashCode(); - result = 31 * result + (valueContainsNull ? 1 : 0); - return result; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java deleted file mode 100644 index 0f819fb01a76..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, - * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and - * Array[Metadata]. JSON is used for serialization. - * - * The default constructor is private. User should use [[MetadataBuilder]]. - */ -class Metadata extends org.apache.spark.sql.catalyst.util.Metadata { - Metadata(scala.collection.immutable.Map map) { - super(map); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java deleted file mode 100644 index 6e6b12f0722c..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. - */ -public class MetadataBuilder extends org.apache.spark.sql.catalyst.util.MetadataBuilder { - @Override - public Metadata build() { - return new Metadata(getMap()); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java deleted file mode 100644 index 6d5ecdf46e55..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing null and NULL values. - * - * {@code NullType} is represented by the singleton object {@link DataType#NullType}. - */ -public class NullType extends DataType { - protected NullType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java deleted file mode 100644 index 7d7604b4e3d2..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/ShortType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing short and Short values. - * - * {@code ShortType} is represented by the singleton object {@link DataType#ShortType}. - */ -public class ShortType extends DataType { - protected ShortType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java deleted file mode 100644 index f4ba0c07c9c6..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StringType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing String values. - * - * {@code StringType} is represented by the singleton object {@link DataType#StringType}. 
- */ -public class StringType extends DataType { - protected StringType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java deleted file mode 100644 index 7c60d492bcdf..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.util.Map; - -/** - * A StructField object represents a field in a StructType object. - * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, - * and {@code boolean nullable}. The field of {@code name} is the name of a StructField. - * The field of {@code dataType} specifies the data type of a StructField. - * The field of {@code nullable} specifies if values of a StructField can contain {@code null} - * values. - * The field of {@code metadata} provides extra information of the StructField. - * - * To create a {@link StructField}, - * {@link DataType#createStructField(String, DataType, boolean, Metadata)} - * should be used. - */ -public class StructField { - private String name; - private DataType dataType; - private boolean nullable; - private Metadata metadata; - - protected StructField( - String name, - DataType dataType, - boolean nullable, - Metadata metadata) { - this.name = name; - this.dataType = dataType; - this.nullable = nullable; - this.metadata = metadata; - } - - public String getName() { - return name; - } - - public DataType getDataType() { - return dataType; - } - - public boolean isNullable() { - return nullable; - } - - public Metadata getMetadata() { - return metadata; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - StructField that = (StructField) o; - - if (nullable != that.nullable) return false; - if (!dataType.equals(that.dataType)) return false; - if (!name.equals(that.name)) return false; - if (!metadata.equals(that.metadata)) return false; - - return true; - } - - @Override - public int hashCode() { - int result = name.hashCode(); - result = 31 * result + dataType.hashCode(); - result = 31 * result + (nullable ? 
1 : 0); - result = 31 * result + metadata.hashCode(); - return result; - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java deleted file mode 100644 index a4b501efd9a1..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructType.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.util.Arrays; - -/** - * The data type representing Rows. - * A StructType object comprises an array of StructFields. - * - * To create an {@link StructType}, - * {@link DataType#createStructType(java.util.List)} or - * {@link DataType#createStructType(StructField[])} - * should be used. - */ -public class StructType extends DataType { - private StructField[] fields; - - protected StructType(StructField[] fields) { - this.fields = fields; - } - - public StructField[] getFields() { - return fields; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - StructType that = (StructType) o; - - if (!Arrays.equals(fields, that.fields)) return false; - - return true; - } - - @Override - public int hashCode() { - return Arrays.hashCode(fields); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java deleted file mode 100644 index 06d44c731cdf..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/TimestampType.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -/** - * The data type representing java.sql.Timestamp values. - * - * {@code TimestampType} is represented by the singleton object {@link DataType#TimestampType}. 
- */ -public class TimestampType extends DataType { - protected TimestampType() {} -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java deleted file mode 100644 index f0d079d25b5d..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.io.Serializable; - -import org.apache.spark.annotation.DeveloperApi; - -/** - * ::DeveloperApi:: - * The data type representing User-Defined Types (UDTs). - * UDTs may use any other DataType for an underlying representation. - */ -@DeveloperApi -public abstract class UserDefinedType extends DataType implements Serializable { - - protected UserDefinedType() { } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - @SuppressWarnings("unchecked") - UserDefinedType that = (UserDefinedType) o; - return this.sqlType().equals(that.sqlType()); - } - - /** Underlying storage type for this UDT */ - public abstract DataType sqlType(); - - /** Convert the user type to a SQL datum */ - public abstract Object serialize(Object obj); - - /** Convert a SQL datum to the user type */ - public abstract UserType deserialize(Object datum); - - /** Class object for the UserType */ - public abstract Class userClass(); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d8efce0cb43e..d9f3b3a53f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,7 +24,6 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration - import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD @@ -32,14 +31,14 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} +import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.UserDefinedType -import org.apache.spark.sql.execution.{SparkStrategies, _} +import 
org.apache.spark.sql.execution._ import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation} +import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.types._ /** * :: AlphaComponent :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 80787b61ce1b..686bcdfbb4ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList} - +import java.util.{List => JList} import scala.collection.JavaConversions._ @@ -37,8 +36,9 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types.{BooleanType, StructType} import org.apache.spark.storage.StorageLevel /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 65358b7d4ea8..f10ee7b66feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql +import scala.util.parsing.combinator.RegexParsers + import org.apache.spark.sql.catalyst.{SqlLexical, AbstractSparkSQLParser} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand} +import org.apache.spark.sql.types.StringType -import scala.util.parsing.combinator.RegexParsers /** * The top level Spark SQL parser. 
This parser recognizes syntaxes that are available for all SQL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 7f868cd4afca..a75f55992826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,15 +23,13 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.{SQLContext, StructType => SStructType} -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} -import org.apache.spark.sql.types.util.DataTypeConversions -import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** @@ -126,9 +124,8 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { @DeveloperApi def applySchema(rowRDD: JavaRDD[Row], schema: StructType): JavaSchemaRDD = { val scalaRowRDD = rowRDD.rdd.map(r => r.row) - val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] val logicalPlan = - LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext) + LogicalRDD(schema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -184,10 +181,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord val appliedScalaSchema = - Option(asScalaDataType(schema)).getOrElse( + Option(schema).getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema( - json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType] + json.rdd, 1.0, columnNameOfCorruptJsonRecord))) val scalaRowRDD = JsonRDD.jsonStringToRow( json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = @@ -218,43 +215,25 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val (dataType, nullable) = property.getPropertyType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - case c: Class[_] if c == classOf[java.lang.String] => - (org.apache.spark.sql.StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => - (org.apache.spark.sql.ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => - (org.apache.spark.sql.IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => - (org.apache.spark.sql.LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => - (org.apache.spark.sql.DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => - (org.apache.spark.sql.ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => - (org.apache.spark.sql.FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => - (org.apache.spark.sql.BooleanType, false) - - case c: Class[_] if c == 
classOf[java.lang.Short] => - (org.apache.spark.sql.ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => - (org.apache.spark.sql.IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => - (org.apache.spark.sql.LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => - (org.apache.spark.sql.DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => - (org.apache.spark.sql.ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => - (org.apache.spark.sql.FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => - (org.apache.spark.sql.BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => - (org.apache.spark.sql.DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => - (org.apache.spark.sql.DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => - (org.apache.spark.sql.TimestampType, true) + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) } AttributeReference(property.getName, dataType, nullable)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 5b9c612487ac..9e10e532fb01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql.api.java import java.util.{List => JList} import org.apache.spark.Partitioner -import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaRDDLike} import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.sql.types.util.DataTypeConversions +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import DataTypeConversions._ -import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel /** @@ -59,8 +58,7 @@ class JavaSchemaRDD( override def toString: String = baseSchemaRDD.toString /** Returns the schema of this JavaSchemaRDD (represented by a 
StructType). */ - def schema: StructType = - asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] + def schema: StructType = baseSchemaRDD.schema.asInstanceOf[StructType] // ======================================================================= // Base RDD functions that do NOT change schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 401798e317e9..207e2805fffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.api.java -import org.apache.spark.sql.catalyst.types.decimal.Decimal - import scala.annotation.varargs import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala index 158f26e3d445..4186c274515a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.api.java import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} -import org.apache.spark.sql.types.util.DataTypeConversions._ +import org.apache.spark.sql.types.DataType /** * A collection of functions that allow Java users to register UDFs. In order to handle functions @@ -38,10 +38,9 @@ private[java] trait UDFRegistration { println(s""" |def registerFunction( | name: String, f: UDF$i[$extTypeArgs, _], @transient dataType: DataType) = { - | val scalaType = asScalaDataType(dataType) | sqlContext.functionRegistry.registerFunction( | name, - | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), scalaType, e)) + | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), dataType, e)) |} """.stripMargin) } @@ -94,159 +93,159 @@ private[java] trait UDFRegistration { */ // scalastyle:off - def registerFunction(name: String, f: UDF1[_, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF1[_, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), dataType, e)) } - def registerFunction(name: String, f: UDF2[_, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF2[_, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF3[_, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF3[_, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: 
Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF4[_, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF4[_, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF5[_, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF5[_, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF6[_, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF6[_, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF7[_, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF7[_, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF8[_, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: 
Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + + def registerFunction( + name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val 
scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), 
dataType, e)) } - def registerFunction(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => 
ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - def registerFunction(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { - val scalaType = asScalaDataType(dataType) + def registerFunction( + name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], @transient dataType: DataType) = { sqlContext.functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), dataType, e)) } - // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala deleted file mode 100644 index a7d0f4f127ec..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
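With the conversion step gone, the registerFunction overloads above hand the supplied DataType straight to ScalaUdf. A minimal usage sketch, assuming a JavaSQLContext instance, a made-up "strLen" function name, and the existing Java UDF1 interface in org.apache.spark.sql.api.java:

    import org.apache.spark.sql.api.java.{JavaSQLContext, UDF1}
    import org.apache.spark.sql.types.IntegerType

    // Illustrative only: register a one-argument UDF with an explicit return type.
    def registerStrLen(javaSqlCtx: JavaSQLContext): Unit = {
      javaSqlCtx.registerFunction("strLen", new UDF1[String, Integer] {
        override def call(s: String): Integer = Integer.valueOf(s.length)
      }, IntegerType)
    }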
- */ - -package org.apache.spark.sql.api.java - -import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType} -import org.apache.spark.sql.{DataType => ScalaDataType} -import org.apache.spark.sql.types.util.DataTypeConversions - -/** - * Scala wrapper for a Java UserDefinedType - */ -private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType]) - extends ScalaUserDefinedType[UserType] with Serializable { - - /** Underlying storage type for this UDT */ - val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType()) - - /** Convert the user type to a SQL datum */ - def serialize(obj: Any): Any = javaUDT.serialize(obj) - - /** Convert a SQL datum to the user type */ - def deserialize(datum: Any): UserType = javaUDT.deserialize(datum) - - val userClass: java.lang.Class[UserType] = javaUDT.userClass() -} - -/** - * Java wrapper for a Scala UserDefinedType - */ -private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType]) - extends UserDefinedType[UserType] with Serializable { - - /** Underlying storage type for this UDT */ - val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType) - - /** Convert the user type to a SQL datum */ - def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object] - - /** Convert a SQL datum to the user type */ - def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum) - - val userClass: java.lang.Class[UserType] = scalaUDT.userClass -} - -private[sql] object UDTWrappers { - - def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = { - udtType match { - case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT - case _ => new JavaToScalaUDTWrapper(udtType) - } - } - - def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = { - udtType match { - case t: JavaToScalaUDTWrapper[_] => t.javaUDT - case _ => new ScalaToJavaUDTWrapper(udtType) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 538dd5b73466..91c4c105b14e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType} import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.types.{BinaryType, DataType, NativeType} /** * An `Iterator` like trait used to extract values from columnar byte buffer. 
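With the UDTWrappers indirection deleted above, a user-defined type is written once against the unified contract (sqlType, serialize, deserialize, userClass) in org.apache.spark.sql.types. A minimal sketch, assuming a hypothetical ExamplePoint class stored as a two-element array of doubles:

    import org.apache.spark.sql.types._

    // Hypothetical user class, for illustration only.
    class ExamplePoint(val x: Double, val y: Double) extends Serializable

    // Sketch of a UDT that stores points as an array of doubles.
    class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

      override def serialize(obj: Any): Any = obj match {
        case p: ExamplePoint => Seq(p.x, p.y)
      }

      override def deserialize(datum: Any): ExamplePoint = datum match {
        case values: Seq[_] =>
          new ExamplePoint(values(0).asInstanceOf[Double], values(1).asInstanceOf[Double])
      }

      override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
    }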
When a value is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index c68dceef3b14..3a4977b836af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnBuilder._ import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.types._ private[sql] trait ColumnBuilder { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 668efe4a3b2a..391b3dae5c8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index ab66c85c4f24..fcf2faa0914c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -24,8 +24,8 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types._ /** * An abstract class that represents type of a column. 
Used to append/extract Java objects into/from diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index 27ac5f4dbdbb..7dff9deac8dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.types.NativeType private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor { this: NativeColumnAccessor[T] => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 628d9cec41d6..aead768ecdf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -21,8 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.types.NativeType /** * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index acb06cb5376b..879d29bcfa6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -21,8 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.types.NativeType private[sql] trait Encoder[T <: NativeType] { def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 29edcf17242c..64673248394c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -25,10 +25,11 @@ import scala.reflect.runtime.universe.runtimeMirror import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils + private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 069e95019530..20b14834bb0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{StructType, Row, SQLContext} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.types.StructType /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 7c3bf947e743..4abe26fe4afc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ case class AggregateEvaluation( schema: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 84d96e612f0d..131146012eca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -29,7 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types.decimal.Decimal import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0652d2ff7c9a..0cc9d049c964 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.sources.{CreateTempTableUsing, CreateTableUsing} import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import 
org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources.{CreateTempTableUsing, CreateTableUsing} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 61be5ed2db65..46245cd5a186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SchemaRDD, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 5a41399971dd..741ccb8fb8fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} -import org.apache.spark.sql.catalyst.types.decimal.Decimal - import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -33,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ import org.apache.spark.{Accumulator, Logging => SparkLogging} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index f5c02224c82a..1af96c28d5fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.json import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StructType + private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 00449c200704..c92ec543e293 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.types.util.DataTypeConversions - import java.io.StringWriter import scala.collection.Map @@ -34,8 +31,9 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ +import 
org.apache.spark.sql.types.decimal.Decimal import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { @@ -246,7 +244,7 @@ private[sql] object JsonRDD extends Logging { // The value associated with the key is an array. // Handle inner structs of an array. def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { - case ArrayType(StructType(Nil), containsNull) => { + case ArrayType(e: StructType, containsNull) => { // The elements of this arrays are structs. v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { element => allKeysWithValueTypes(element) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 1fd8e6220f83..b75266d5aa40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -116,358 +116,9 @@ package object sql { @DeveloperApi val Row = catalyst.expressions.Row - /** - * :: DeveloperApi :: - * - * The base type of all Spark SQL data types. - * - * @group dataType - */ - @DeveloperApi - type DataType = catalyst.types.DataType - - @DeveloperApi - val DataType = catalyst.types.DataType - - /** - * :: DeveloperApi :: - * - * The data type representing `String` values - * - * @group dataType - */ - @DeveloperApi - val StringType = catalyst.types.StringType - - /** - * :: DeveloperApi :: - * - * The data type representing `Array[Byte]` values. - * - * @group dataType - */ - @DeveloperApi - val BinaryType = catalyst.types.BinaryType - - /** - * :: DeveloperApi :: - * - * The data type representing `Boolean` values. - * - *@group dataType - */ - @DeveloperApi - val BooleanType = catalyst.types.BooleanType - - /** - * :: DeveloperApi :: - * - * The data type representing `java.sql.Timestamp` values. - * - * @group dataType - */ - @DeveloperApi - val TimestampType = catalyst.types.TimestampType - - /** - * :: DeveloperApi :: - * - * The data type representing `java.sql.Date` values. - * - * @group dataType - */ - @DeveloperApi - val DateType = catalyst.types.DateType - - /** - * :: DeveloperApi :: - * - * The data type representing `scala.math.BigDecimal` values. - * - * TODO(matei): explain precision and scale - * - * @group dataType - */ - @DeveloperApi - type DecimalType = catalyst.types.DecimalType - - /** - * :: DeveloperApi :: - * - * The data type representing `scala.math.BigDecimal` values. - * - * TODO(matei): explain precision and scale - * - * @group dataType - */ - @DeveloperApi - val DecimalType = catalyst.types.DecimalType - - /** - * :: DeveloperApi :: - * - * The data type representing `Double` values. - * - * @group dataType - */ - @DeveloperApi - val DoubleType = catalyst.types.DoubleType - - /** - * :: DeveloperApi :: - * - * The data type representing `Float` values. - * - * @group dataType - */ - @DeveloperApi - val FloatType = catalyst.types.FloatType - - /** - * :: DeveloperApi :: - * - * The data type representing `Byte` values. - * - * @group dataType - */ - @DeveloperApi - val ByteType = catalyst.types.ByteType - - /** - * :: DeveloperApi :: - * - * The data type representing `Int` values. - * - * @group dataType - */ - @DeveloperApi - val IntegerType = catalyst.types.IntegerType - - /** - * :: DeveloperApi :: - * - * The data type representing `Long` values. - * - * @group dataType - */ - @DeveloperApi - val LongType = catalyst.types.LongType - - /** - * :: DeveloperApi :: - * - * The data type representing `Short` values. 
- * - * @group dataType - */ - @DeveloperApi - val ShortType = catalyst.types.ShortType - - /** - * :: DeveloperApi :: - * - * The data type representing `NULL` values. - * - * @group dataType - */ - @DeveloperApi - val NullType = catalyst.types.NullType - - /** - * :: DeveloperApi :: - * - * The data type for collections of multiple values. - * Internally these are represented as columns that contain a ``scala.collection.Seq``. - * - * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and - * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array has `null` values. - * - * @group dataType - */ - @DeveloperApi - type ArrayType = catalyst.types.ArrayType - - /** - * :: DeveloperApi :: - * - * An [[ArrayType]] object can be constructed with two ways, - * {{{ - * ArrayType(elementType: DataType, containsNull: Boolean) - * }}} and - * {{{ - * ArrayType(elementType: DataType) - * }}} - * For `ArrayType(elementType)`, the field of `containsNull` is set to `false`. - * - * @group dataType - */ - @DeveloperApi - val ArrayType = catalyst.types.ArrayType - - /** - * :: DeveloperApi :: - * - * The data type representing `Map`s. A [[MapType]] object comprises three fields, - * `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`. - * The field of `keyType` is used to specify the type of keys in the map. - * The field of `valueType` is used to specify the type of values in the map. - * The field of `valueContainsNull` is used to specify if values of this map has `null` values. - * For values of a MapType column, keys are not allowed to have `null` values. - * - * @group dataType - */ - @DeveloperApi - type MapType = catalyst.types.MapType - - /** - * :: DeveloperApi :: - * - * A [[MapType]] object can be constructed with two ways, - * {{{ - * MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) - * }}} and - * {{{ - * MapType(keyType: DataType, valueType: DataType) - * }}} - * For `MapType(keyType: DataType, valueType: DataType)`, - * the field of `valueContainsNull` is set to `true`. - * - * @group dataType - */ - @DeveloperApi - val MapType = catalyst.types.MapType - - /** - * :: DeveloperApi :: - * - * The data type representing [[Row]]s. - * A [[StructType]] object comprises a [[Seq]] of [[StructField]]s. - * - * @group dataType - */ - @DeveloperApi - type StructType = catalyst.types.StructType - - /** - * :: DeveloperApi :: - * - * A [[StructType]] object can be constructed by - * {{{ - * StructType(fields: Seq[StructField]) - * }}} - * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. - * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. - * If a provided name does not have a matching field, it will be ignored. For the case - * of extracting a single StructField, a `null` will be returned. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val struct = - * StructType( - * StructField("a", IntegerType, true) :: - * StructField("b", LongType, false) :: - * StructField("c", BooleanType, false) :: Nil) - * - * // Extract a single StructField. - * val singleField = struct("b") - * // singleField: StructField = StructField(b,LongType,false) - * - * // This struct does not have a field called "d". null will be returned. - * val nonExisting = struct("d") - * // nonExisting: StructField = null - * - * // Extract multiple StructFields. 
Field names are provided in a set. - * // A StructType object will be returned. - * val twoFields = struct(Set("b", "c")) - * // twoFields: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * - * // Those names do not have matching fields will be ignored. - * // For the case shown below, "d" will be ignored and - * // it is treated as struct(Set("b", "c")). - * val ignoreNonExisting = struct(Set("b", "c", "d")) - * // ignoreNonExisting: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * }}} - * - * A [[Row]] object is used as a value of the StructType. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val innerStruct = - * StructType( - * StructField("f1", IntegerType, true) :: - * StructField("f2", LongType, false) :: - * StructField("f3", BooleanType, false) :: Nil) - * - * val struct = StructType( - * StructField("a", innerStruct, true) :: Nil) - * - * // Create a Row with the schema defined by struct - * val row = Row(Row(1, 2, true)) - * // row: Row = [[1,2,true]] - * }}} - * - * @group dataType - */ - @DeveloperApi - val StructType = catalyst.types.StructType - - /** - * :: DeveloperApi :: - * - * A [[StructField]] object represents a field in a [[StructType]] object. - * A [[StructField]] object comprises three fields, `name: [[String]]`, `dataType: [[DataType]]`, - * and `nullable: Boolean`. The field of `name` is the name of a `StructField`. The field of - * `dataType` specifies the data type of a `StructField`. - * The field of `nullable` specifies if values of a `StructField` can contain `null` values. - * - * @group field - */ - @DeveloperApi - type StructField = catalyst.types.StructField - - /** - * :: DeveloperApi :: - * - * A [[StructField]] object can be constructed by - * {{{ - * StructField(name: String, dataType: DataType, nullable: Boolean) - * }}} - * - * @group dataType - */ - @DeveloperApi - val StructField = catalyst.types.StructField - /** * Converts a logical plan into zero or more SparkPlans. */ @DeveloperApi type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] - - /** - * :: DeveloperApi :: - * - * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, - * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and - * Array[Metadata]. JSON is used for serialization. - * - * The default constructor is private. User should use either [[MetadataBuilder]] or - * [[Metadata$#fromJson]] to create Metadata instances. - * - * @param map an immutable map that stores the data - */ - @DeveloperApi - type Metadata = catalyst.util.Metadata - - /** - * :: DeveloperApi :: - * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. 
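The Metadata and MetadataBuilder aliases documented above are dropped from the package object along with the type aliases. A small sketch of attaching metadata to a field, assuming both classes are reachable from org.apache.spark.sql.types after this move and using a made-up "comment" key:

    import org.apache.spark.sql.types._

    // Illustrative only: build metadata and attach it to a field.
    val md: Metadata = new MetadataBuilder()
      .putString("comment", "age of the person, in years")
      .putLong("maxValue", 150L)
      .build()

    val ageField = StructField("age", IntegerType, nullable = true, md)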
- */ - @DeveloperApi - type MetadataBuilder = catalyst.util.MetadataBuilder } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 1bbb66aaa19a..7f437c40777f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.catalyst.types.decimal.Decimal - import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /** * Collection of converters of Parquet types (group and primitive types) that @@ -91,8 +90,8 @@ private[sql] object CatalystConverter { case ArrayType(elementType: DataType, true) => { new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent) } - case StructType(fields: Seq[StructField]) => { - new CatalystStructConverter(fields.toArray, fieldIndex, parent) + case StructType(fields: Array[StructField]) => { + new CatalystStructConverter(fields, fieldIndex, parent) } case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { new CatalystMapConverter( @@ -436,7 +435,7 @@ private[parquet] object CatalystArrayConverter { * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.catalyst.types.ArrayType]]. + * [[org.apache.spark.sql.types.ArrayType]]. * * @param elementType The type of the array elements (complex or primitive) * @param index The position of this (array) field inside its parent converter @@ -500,7 +499,7 @@ private[parquet] class CatalystArrayConverter( * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array (see * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.catalyst.types.ArrayType]]. + * [[org.apache.spark.sql.types.ArrayType]]. * * @param elementType The type of the array elements (native) * @param index The position of this (array) field inside its parent converter @@ -621,7 +620,7 @@ private[parquet] class CatalystNativeArrayConverter( * A `parquet.io.api.GroupConverter` that converts a single-element groups that * match the characteristics of an array contains null (see * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.catalyst.types.ArrayType]]. + * [[org.apache.spark.sql.types.ArrayType]]. * * @param elementType The type of the array elements (complex or primitive) * @param index The position of this (array) field inside its parent converter @@ -727,7 +726,7 @@ private[parquet] class CatalystStructConverter( * A `parquet.io.api.GroupConverter` that converts two-element groups that * match the characteristics of a map (see * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an - * [[org.apache.spark.sql.catalyst.types.MapType]]. + * [[org.apache.spark.sql.types.MapType]]. 
* * @param schema * @param index diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 56e7d11b2fee..f08350878f23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -29,7 +29,7 @@ import parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 9049eb5932b7..af7248fdf451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -29,8 +29,8 @@ import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /** * A `parquet.io.api.RecordMaterializer` for Rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 97447871a11e..6d8c682ccced 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -36,7 +36,7 @@ import parquet.schema.Type.Repetition import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ // Implicits import scala.collection.JavaConversions._ @@ -80,7 +80,7 @@ private[parquet] object ParquetTypesConverter extends Logging { /** * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.catalyst.types.DataType]]. + * [[org.apache.spark.sql.types.DataType]]. * * We apply the following conversion rules: *
      @@ -191,7 +191,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } /** - * For a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] return + * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return * the name of the corresponding Parquet primitive type or None if the given type * is not primitive. * @@ -231,21 +231,21 @@ private[parquet] object ParquetTypesConverter extends Logging { } /** - * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into + * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into * the corresponding Parquet `Type`. * * The conversion follows the rules below: *
<ul> * <li>Primitive types are converted into Parquet's primitive types.</li> - * <li>[[org.apache.spark.sql.catalyst.types.StructType]]s are converted + * <li>[[org.apache.spark.sql.types.StructType]]s are converted * into Parquet's `GroupType` with the corresponding field types.</li> - * <li>[[org.apache.spark.sql.catalyst.types.ArrayType]]s are converted + * <li>[[org.apache.spark.sql.types.ArrayType]]s are converted * into a 2-level nested group, where the outer group has the inner * group as sole field. The inner group has name `values` and * repetition level `REPEATED` and has the element type of * the array as schema. We use Parquet's `ConversionPatterns` for this * purpose.</li> - * <li>[[org.apache.spark.sql.catalyst.types.MapType]]s are converted + * <li>[[org.apache.spark.sql.types.MapType]]s are converted * into a nested (2-level) Parquet `GroupType` with two fields: a key * type and a value type. The nested group has repetition level * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns` * for this purpose
@@ -319,7 +319,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val fields = structFields.map { field => fromDataType(field.dataType, field.name, field.nullable, inArray = false) } - new ParquetGroupType(repetition, name, fields) + new ParquetGroupType(repetition, name, fields.toSeq) } case MapType(keyType, valueType, valueContainsNull) => { val parquetKeyType =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 55a2728a85cc..1b50afbbabcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.parquet import java.util.{List => JList} +import scala.collection.JavaConversions._ + import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job} -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import parquet.hadoop.ParquetInputFormat import parquet.hadoop.util.ContextUtil @@ -30,13 +31,11 @@ import parquet.hadoop.util.ContextUtil import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{Partition => SparkPartition, Logging} import org.apache.spark.rdd.{NewHadoopPartition, RDD} - import org.apache.spark.sql.{SQLConf, Row, SQLContext} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, StructField, StructType} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import scala.collection.JavaConversions._ /** * Allows creation of parquet based tables using the syntax
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala index 4d87f6817dcb..12b59ba20bb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} /** @@ -27,7 +27,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation) extends LeafNode with MultiInstanceRelation { - override val output = relation.schema.toAttributes + override val output: Seq[AttributeReference] = relation.schema.toAttributes // Logical Relations are distinct if they have different output for the sake of transformations.
override def equals(other: Any) = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index f8741e008209..4cc9641c4d9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -23,11 +23,11 @@ import scala.util.parsing.combinator.PackratParsers import org.apache.spark.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * A parser for foreign DDL commands. @@ -162,10 +162,10 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected lazy val structType: Parser[DataType] = (STRUCT ~> "<" ~> repsep(structField, ",") <~ ">" ^^ { - case fields => new StructType(fields) + case fields => StructType(fields) }) | (STRUCT ~> "<>" ^^ { - case fields => new StructType(Nil) + case fields => StructType(Nil) }) private[sql] lazy val dataType: Parser[DataType] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7f5564baa00f..cd82cc6ecb61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, StructType} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} +import org.apache.spark.sql.types.StructType /** * ::DeveloperApi:: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index b9569e96c031..006b16fbe07b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.test import java.util import scala.collection.JavaConverters._ - -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * An example class to demonstrate UDT in Scala, Java, and Python. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala deleted file mode 100644 index d4ef51798169..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.types.util - -import java.text.SimpleDateFormat - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, - MetadataBuilder => JMetaDataBuilder, UDTWrappers} -import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.types.UserDefinedType - -protected[sql] object DataTypeConversions { - - /** - * Returns the equivalent StructField in Scala for the given StructField in Java. - */ - def asJavaStructField(scalaStructField: StructField): JStructField = { - JDataType.createStructField( - scalaStructField.name, - asJavaDataType(scalaStructField.dataType), - scalaStructField.nullable, - (new JMetaDataBuilder).withMetadata(scalaStructField.metadata).build()) - } - - /** - * Returns the equivalent DataType in Java for the given DataType in Scala. - */ - def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { - case udtType: UserDefinedType[_] => - UDTWrappers.wrapAsJava(udtType) - - case StringType => JDataType.StringType - case BinaryType => JDataType.BinaryType - case BooleanType => JDataType.BooleanType - case DateType => JDataType.DateType - case TimestampType => JDataType.TimestampType - case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale) - case DecimalType.Unlimited => new JDecimalType() - case DoubleType => JDataType.DoubleType - case FloatType => JDataType.FloatType - case ByteType => JDataType.ByteType - case IntegerType => JDataType.IntegerType - case LongType => JDataType.LongType - case ShortType => JDataType.ShortType - case NullType => JDataType.NullType - - case arrayType: ArrayType => JDataType.createArrayType( - asJavaDataType(arrayType.elementType), arrayType.containsNull) - case mapType: MapType => JDataType.createMapType( - asJavaDataType(mapType.keyType), - asJavaDataType(mapType.valueType), - mapType.valueContainsNull) - case structType: StructType => JDataType.createStructType( - structType.fields.map(asJavaStructField).asJava) - } - - /** - * Returns the equivalent StructField in Scala for the given StructField in Java. - */ - def asScalaStructField(javaStructField: JStructField): StructField = { - StructField( - javaStructField.getName, - asScalaDataType(javaStructField.getDataType), - javaStructField.isNullable, - javaStructField.getMetadata) - } - - /** - * Returns the equivalent DataType in Scala for the given DataType in Java. 
- */ - def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { - case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => - UDTWrappers.wrapAsScala(udtType) - - case stringType: org.apache.spark.sql.api.java.StringType => - StringType - case binaryType: org.apache.spark.sql.api.java.BinaryType => - BinaryType - case booleanType: org.apache.spark.sql.api.java.BooleanType => - BooleanType - case dateType: org.apache.spark.sql.api.java.DateType => - DateType - case timestampType: org.apache.spark.sql.api.java.TimestampType => - TimestampType - case decimalType: org.apache.spark.sql.api.java.DecimalType => - if (decimalType.isFixed) { - DecimalType(decimalType.getPrecision, decimalType.getScale) - } else { - DecimalType.Unlimited - } - case doubleType: org.apache.spark.sql.api.java.DoubleType => - DoubleType - case floatType: org.apache.spark.sql.api.java.FloatType => - FloatType - case byteType: org.apache.spark.sql.api.java.ByteType => - ByteType - case integerType: org.apache.spark.sql.api.java.IntegerType => - IntegerType - case longType: org.apache.spark.sql.api.java.LongType => - LongType - case shortType: org.apache.spark.sql.api.java.ShortType => - ShortType - - case arrayType: org.apache.spark.sql.api.java.ArrayType => - ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) - case mapType: org.apache.spark.sql.api.java.MapType => - MapType( - asScalaDataType(mapType.getKeyType), - asScalaDataType(mapType.getValueType), - mapType.isValueContainsNull) - case structType: org.apache.spark.sql.api.java.StructType => - StructType(structType.getFields.map(asScalaStructField)) - } - - def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { - // JDBC escape string - if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) - } else { - java.sql.Date.valueOf(s) - } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } - } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) - } - } - - /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type - case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) - case (other, _) => other - } - - /** Converts Java objects to catalyst rows / types */ - def convertCatalystToJava(a: Any): Any = a match { - case d: scala.math.BigDecimal => d.underlying() - case other => other - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java index a9a11285def5..88017eb47d90 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -19,15 +19,12 @@ import java.io.Serializable; -import org.apache.spark.sql.api.java.UDF1; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import 
org.junit.runners.Suite; -import org.junit.runner.RunWith; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -60,7 +57,7 @@ public void udf1Test() { public Integer call(String str) throws Exception { return str.length(); } - }, DataType.IntegerType); + }, DataTypes.IntegerType); // TODO: Why do we need this cast? Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test')").first(); @@ -81,7 +78,7 @@ public void udf2Test() { public Integer call(String str1, String str2) throws Exception { return str1.length() + str2.length(); } - }, DataType.IntegerType); + }, DataTypes.IntegerType); // TODO: Why do we need this cast? Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index a04b8060cd65..de586ba63591 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.types.*; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -93,9 +94,9 @@ public Row call(Person person) throws Exception { }); List fields = new ArrayList(2); - fields.add(DataType.createStructField("name", DataType.StringType, false)); - fields.add(DataType.createStructField("age", DataType.IntegerType, false)); - StructType schema = DataType.createStructType(fields); + fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); + StructType schema = DataTypes.createStructType(fields); JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); schemaRDD.registerTempTable("people"); @@ -118,14 +119,14 @@ public void applySchemaToJSON() { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataType.createStructField("bigInteger", new DecimalType(), true)); - fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); - fields.add(DataType.createStructField("double", DataType.DoubleType, true)); - fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); - fields.add(DataType.createStructField("long", DataType.LongType, true)); - fields.add(DataType.createStructField("null", DataType.StringType, true)); - fields.add(DataType.createStructField("string", DataType.StringType, true)); - StructType expectedSchema = DataType.createStructType(fields); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); + fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); + fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("integer", DataTypes.IntegerType, true)); + 
fields.add(DataTypes.createStructField("long", DataTypes.LongType, true)); + fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); + fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); + StructType expectedSchema = DataTypes.createStructType(fields); List expectedResult = new ArrayList(2); expectedResult.add( Row.create( diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java deleted file mode 100644 index 8396a29c61c4..000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import java.util.List; -import java.util.ArrayList; - -import org.junit.Assert; -import org.junit.Test; - -import org.apache.spark.sql.types.util.DataTypeConversions; - -public class JavaSideDataTypeConversionSuite { - public void checkDataType(DataType javaDataType) { - org.apache.spark.sql.catalyst.types.DataType scalaDataType = - DataTypeConversions.asScalaDataType(javaDataType); - DataType actual = DataTypeConversions.asJavaDataType(scalaDataType); - Assert.assertEquals(javaDataType, actual); - } - - @Test - public void createDataTypes() { - // Simple DataTypes. - checkDataType(DataType.StringType); - checkDataType(DataType.BinaryType); - checkDataType(DataType.BooleanType); - checkDataType(DataType.DateType); - checkDataType(DataType.TimestampType); - checkDataType(new DecimalType()); - checkDataType(new DecimalType(10, 4)); - checkDataType(DataType.DoubleType); - checkDataType(DataType.FloatType); - checkDataType(DataType.ByteType); - checkDataType(DataType.IntegerType); - checkDataType(DataType.LongType); - checkDataType(DataType.ShortType); - - // Simple ArrayType. - DataType simpleJavaArrayType = DataType.createArrayType(DataType.StringType, true); - checkDataType(simpleJavaArrayType); - - // Simple MapType. - DataType simpleJavaMapType = DataType.createMapType(DataType.StringType, DataType.LongType); - checkDataType(simpleJavaMapType); - - // Simple StructType. - List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); - simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); - simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); - simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); - DataType simpleJavaStructType = DataType.createStructType(simpleFields); - checkDataType(simpleJavaStructType); - - // Complex StructType. 
- List complexFields = new ArrayList(); - complexFields.add(DataType.createStructField("simpleArray", simpleJavaArrayType, true)); - complexFields.add(DataType.createStructField("simpleMap", simpleJavaMapType, true)); - complexFields.add(DataType.createStructField("simpleStruct", simpleJavaStructType, true)); - complexFields.add(DataType.createStructField("boolean", DataType.BooleanType, false)); - DataType complexJavaStructType = DataType.createStructType(complexFields); - checkDataType(complexJavaStructType); - - // Complex ArrayType. - DataType complexJavaArrayType = DataType.createArrayType(complexJavaStructType, true); - checkDataType(complexJavaArrayType); - - // Complex MapType. - DataType complexJavaMapType = - DataType.createMapType(complexJavaStructType, complexJavaArrayType, false); - checkDataType(complexJavaMapType); - } - - @Test - public void illegalArgument() { - // ArrayType - try { - DataType.createArrayType(null, true); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } - - // MapType - try { - DataType.createMapType(null, DataType.StringType); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } - try { - DataType.createMapType(DataType.StringType, null); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } - try { - DataType.createMapType(null, null); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } - - // StructField - try { - DataType.createStructField(null, DataType.StringType, true); - } catch (IllegalArgumentException expectedException) { - } - try { - DataType.createStructField("name", null, true); - } catch (IllegalArgumentException expectedException) { - } - try { - DataType.createStructField(null, null, true); - } catch (IllegalArgumentException expectedException) { - } - - // StructType - try { - List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); - simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); - simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); - simpleFields.add(null); - DataType.createStructType(simpleFields); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } - try { - List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); - simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); - simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); - DataType.createStructType(simpleFields); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index ab88f3ad10d6..efe622f8bcb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ /* Implicits */ import org.apache.spark.sql.catalyst.dsl._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 811319e0a660..f5b945f468da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.types._ class RowSuite extends FunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bc72daf0880a..cbdb3e64bb66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ /* Implicits */ import org.apache.spark.sql.TestData._ @@ -748,7 +749,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { val metadata = new MetadataBuilder() .putString(docKey, docValue) .build() - val schemaWithMeta = new StructType(Seq( + val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) val personWithMeta = applySchema(person, schemaWithMeta) def validateMetadata(rdd: SchemaRDD): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index cf3a59e54590..40fb8d5779fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 1806a1dd8202..a0d54d17f5f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType -import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types._ @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 8afc3a9fb218..fdbb4282baf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.api.java -import org.apache.spark.sql.catalyst.types.decimal.Decimal - import scala.beans.BeanProperty import org.scalatest.FunSuite @@ -26,6 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.NullType // Implicits import 
scala.collection.JavaConversions._ @@ -78,10 +77,10 @@ class JavaSQLSuite extends FunSuite { schemaRDD.registerTempTable("people") val nullRDD = javaSqlCtx.sql("SELECT null FROM people") - val structFields = nullRDD.schema.getFields() + val structFields = nullRDD.schema.fields assert(structFields.size == 1) - assert(structFields(0).getDataType().isInstanceOf[NullType]) - assert(nullRDD.collect.head.row === Seq(null)) + assert(structFields(0).dataType === NullType) + assert(nullRDD.collect().head.row === Seq(null)) } test("all types in JavaBeans") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala deleted file mode 100644 index 62fe59dd345d..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java - -import org.scalatest.FunSuite - -import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField, StructType => SStructType} -import org.apache.spark.sql.types.util.DataTypeConversions._ - -class ScalaSideDataTypeConversionSuite extends FunSuite { - - def checkDataType(scalaDataType: SDataType) { - val javaDataType = asJavaDataType(scalaDataType) - val actual = asScalaDataType(javaDataType) - assert(scalaDataType === actual, s"Converted data type ${actual} " + - s"does not equal the expected data type ${scalaDataType}") - } - - test("convert data types") { - // Simple DataTypes. - checkDataType(org.apache.spark.sql.StringType) - checkDataType(org.apache.spark.sql.BinaryType) - checkDataType(org.apache.spark.sql.BooleanType) - checkDataType(org.apache.spark.sql.DateType) - checkDataType(org.apache.spark.sql.TimestampType) - checkDataType(org.apache.spark.sql.DecimalType.Unlimited) - checkDataType(org.apache.spark.sql.DoubleType) - checkDataType(org.apache.spark.sql.FloatType) - checkDataType(org.apache.spark.sql.ByteType) - checkDataType(org.apache.spark.sql.IntegerType) - checkDataType(org.apache.spark.sql.LongType) - checkDataType(org.apache.spark.sql.ShortType) - - // Simple ArrayType. - val simpleScalaArrayType = - org.apache.spark.sql.ArrayType(org.apache.spark.sql.StringType, true) - checkDataType(simpleScalaArrayType) - - // Simple MapType. - val simpleScalaMapType = - org.apache.spark.sql.MapType(org.apache.spark.sql.StringType, org.apache.spark.sql.LongType) - checkDataType(simpleScalaMapType) - - // Simple StructType. 
- val simpleScalaStructType = SStructType( - SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) :: - SStructField("b", org.apache.spark.sql.BooleanType, true) :: - SStructField("c", org.apache.spark.sql.LongType, true) :: - SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) - checkDataType(simpleScalaStructType) - - // Complex StructType. - val metadata = new MetadataBuilder() - .putString("name", "age") - .build() - val complexScalaStructType = SStructType( - SStructField("simpleArray", simpleScalaArrayType, true) :: - SStructField("simpleMap", simpleScalaMapType, true) :: - SStructField("simpleStruct", simpleScalaStructType, true) :: - SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: - SStructField("withMeta", org.apache.spark.sql.DoubleType, false, metadata) :: Nil) - checkDataType(complexScalaStructType) - - // Complex ArrayType. - val complexScalaArrayType = - org.apache.spark.sql.ArrayType(complexScalaStructType, true) - checkDataType(complexScalaArrayType) - - // Complex MapType. - val complexScalaMapType = - org.apache.spark.sql.MapType(complexScalaStructType, complexScalaArrayType, false) - checkDataType(complexScalaMapType) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index a9f0851f8826..9be0b38e689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 3f3f35d50188..87e608a8853d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -24,9 +24,9 @@ import org.scalatest.FunSuite import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types._ class ColumnTypeSuite extends FunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index a1f21219eaf2..f941465fa3e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.{DataType, NativeType} +import org.apache.spark.sql.types.{DataType, NativeType} object ColumnarTestUtils { def makeNullRow(length: Int) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 21906e3fdcc6..f95c895587f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.types.DataType class TestNullableColumnAccessor[T <: DataType, JvmType]( buffer: ByteBuffer, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cb73f3da81e2..80bd5c94570c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types._ class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index 1cdb909146d5..c82d9799359c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -22,9 +22,9 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.types.NativeType class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new IntColumnStats, INT) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 73f31c023334..88011631ee4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.IntegralType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends FunSuite { testIntegralDelta(new IntColumnStats, INT, IntDelta) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 4ce2552112c9..08df1db37509 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.types.NativeType class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 7db723d648d8..0b18b4119268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.columnar.compression -import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.types.NativeType class TestCompressibleColumnBuilder[T <: NativeType]( override val columnStats: ColumnStats, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index b09f1ac49553..01c1ce2a6102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -17,19 +17,19 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} -import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ - -import java.sql.{Date, Timestamp} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal +import org.apache.spark.sql.{QueryTest, Row, SQLConf} class JsonSuite extends QueryTest { - import TestJsonData._ + import org.apache.spark.sql.json.TestJsonData._ TestJsonData test("Type promotion") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 6ac67fcafe16..973819aaa4d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -21,8 +21,6 @@ import scala.collection.JavaConversions._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} import parquet.example.data.simple.SimpleGroup import parquet.example.data.{Group, GroupWriter} import parquet.hadoop.api.WriteSupport @@ -32,11 +30,13 @@ import parquet.hadoop.{ParquetFileWriter, ParquetWriter} import 
parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.types.DecimalType import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types.DecimalType import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 0a92336a3cb3..fe781ec05fb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -29,10 +29,10 @@ import parquet.io.api.Binary import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.IntegerType import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils case class TestRDDEntry(key: Int, value: String) @@ -911,20 +911,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA Utils.deleteRecursively(tmpdir) } - test("DataType string parser compatibility") { - val schema = StructType(List( - StructField("c1", IntegerType, false), - StructField("c2", BinaryType, false))) - - val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) - val fromJson = ParquetTypesConverter.convertFromString(schema.json) - - (fromCaseClassString, fromJson).zipped.foreach { (a, b) => - assert(a.name == b.name) - assert(a.dataType === b.dataType) - } - } - test("read/write fixed-length decimals") { for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { val tempDir = getTempFilePath("parquetTest").getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 34d61bf90848..64274950b868 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -24,7 +24,6 @@ import org.scalatest.FunSuite import parquet.schema.MessageTypeParser import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.sql.test.TestSQLContext class ParquetSchemaSuite extends FunSuite with ParquetTest { @@ -148,12 +147,20 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { """.stripMargin) test("DataType string parser compatibility") { - val schema = StructType(List( - StructField("c1", IntegerType, false), - StructField("c2", BinaryType, true))) - - val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) - val fromJson = ParquetTypesConverter.convertFromString(schema.json) + // This is the generated string from previous versions of the Spark SQL, using the following: + // val schema = StructType(List( + // StructField("c1", IntegerType, 
false), + // StructField("c2", BinaryType, true))) + val caseClassString = + "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + + val jsonString = + """ + |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} + """.stripMargin + + val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) + val fromJson = ParquetTypesConverter.convertFromString(jsonString) (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 939b3c0c66de..390538d35a34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.sources import scala.language.existentials import org.apache.spark.sql._ +import org.apache.spark.sql.types._ + class FilteredScanSource extends RelationProvider { override def createRelation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index fee2e22611cd..7900b3e8948d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ +import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { override def createRelation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index a1d2468b2573..382dddcdeac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sources import java.sql.{Timestamp, Date} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 742acba58d77..171d707b138b 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -32,11 +32,11 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.types._ /** * A compatibility layer for interacting with Hive version 0.12.0. 
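A note on the rewritten ParquetSchemaSuite compatibility test above: the pinned JSON literal is simply what the schema's own JSON serialization produces. Below is a minimal sketch, assuming the relocated org.apache.spark.sql.types package and the same two-column schema the test describes; the object name is a placeholder for illustration only.
{{{
import org.apache.spark.sql.types._

// Illustrative only: rebuilds the schema pinned by the updated test and
// prints its JSON serialization, which is the string literal the test now
// hard-codes instead of regenerating it from schema.toString / schema.json.
object SchemaJsonSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(List(
      StructField("c1", IntegerType, false),
      StructField("c2", BinaryType, true)))

    // Expected to match the pinned {"type":"struct", ...} literal in the test.
    println(schema.json)
  }
}
}}}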
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index b82156427a88..bec9d9aca30c 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -30,11 +30,11 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging +import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.types._ /** * A compatibility layer for interacting with Hive version 0.13.1. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9aeebd7e5436..bf56e60cf995 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -37,10 +37,10 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.DecimalType import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} import org.apache.spark.sql.sources.DataSourceStrategy +import org.apache.spark.sql.types._ /** * DEPRECATED: Use HiveContext instead. 
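For readers following the package move, here is a small, self-contained sketch of the constructors documented in the sql package-object Scaladoc removed earlier in this patch, assuming the new org.apache.spark.sql.types package that the imports above switch to; the object name is a placeholder, and the behavior noted in the comments is taken from that Scaladoc rather than verified here.
{{{
import org.apache.spark.sql.types._

// Illustrative only: exercises StructType, StructField, ArrayType and MapType,
// which now live under org.apache.spark.sql.types.
object TypesSketch {
  def main(args: Array[String]): Unit = {
    // StructField(name, dataType, nullable) and StructType(fields), mirroring
    // the example in the removed Scaladoc.
    val struct = StructType(
      StructField("a", IntegerType, true) ::
      StructField("b", LongType, false) ::
      StructField("c", BooleanType, false) :: Nil)

    // Per that Scaladoc, extracting a field by name returns the StructField,
    // and a name with no matching field yields null.
    val singleField = struct("b")

    // ArrayType(elementType) and MapType(keyType, valueType) use the defaults
    // for containsNull and valueContainsNull described in the same Scaladoc.
    val arrayOfStrings = ArrayType(StringType)
    val stringToLong = MapType(StringType, LongType)

    println(singleField)
    println(arrayOfStrings)
    println(stringToLong)
  }
}
}}}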
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index a156d6f7e285..245b847cf4cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -24,9 +24,9 @@ import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.types +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -43,7 +43,7 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.catalyst.types.decimal.Decimal + * org.apache.spark.sql.types.decimal.Decimal * Array[Byte] * java.sql.Date * java.sql.Timestamp @@ -504,7 +504,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) + java.util.Arrays.asList(fields.map(f => f.name) :_*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) :_*)) } /** @@ -618,7 +619,9 @@ private[hive] trait HiveInspectors { case ArrayType(elemType, _) => getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => - getStructTypeInfo(fields.map(_.name), fields.map(_.dataType.toTypeInfo)) + getStructTypeInfo( + java.util.Arrays.asList(fields.map(_.name) :_*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) :_*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 785a6a14f49f..d40f9936fd3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /* Implicit conversions */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 34622b5f5787..b13ef7276bf3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.sql.Date + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.lib.Node @@ -31,10 +32,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import 
org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cdff82e3d04d..6952b126cf89 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConversions._ + import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.sources.CreateTableUsing -import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} +import org.apache.spark.sql.types.StringType -import scala.collection.JavaConversions._ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. 
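Several of the `HiveInspectors` hunks above stop handing Scala `Seq`s straight to Hive's `ObjectInspectorFactory` and `TypeInfo` factories and wrap them as `java.util.List`s instead. A minimal, self-contained sketch of that interop pattern, using hypothetical field names rather than anything from the patch:

```scala
// Hive's factory methods take java.util.List arguments, so the Scala sequences
// are expanded into Java's varargs-based Arrays.asList instead of relying on
// implicit collection conversions.
val fieldNames: Seq[String] = Seq("id", "name")   // hypothetical fields
val javaNames: java.util.List[String] = java.util.Arrays.asList(fieldNames: _*)

println(javaNames)   // [id, name]
```

The same `Arrays.asList(seq: _*)` call is what the patch applies to both the field names and the per-field inspectors.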
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 8ba818af5f9d..781a2e9164c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.types.StringType +import org.apache.spark.sql.types.StringType /** * :: DeveloperApi :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8bbcd6fec1f3..b56175fe7637 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -19,20 +19,18 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{BooleanType, DataType} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive._ +import org.apache.spark.sql.types.{BooleanType, DataType} /** * :: DeveloperApi :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index e70cdeaad4c0..cf72345efa63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.StructType /** * :: DeveloperApi :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7d863f9d89da..d898b876c39f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ 
import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index f90d3607915a..dc23d9a101d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -17,22 +17,21 @@ package org.apache.spark.sql.hive -import java.sql.Date import java.util +import java.sql.Date import java.util.{Locale, TimeZone} -import org.apache.hadoop.hive.serde2.io.DoubleWritable -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.scalatest.FunSuite - import org.apache.hadoop.hive.ql.udf.UDAFPercentile -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.io.DoubleWritable +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable +import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{Literal, Row} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.decimal.Decimal class HiveInspectorSuite extends FunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { @@ -93,7 +92,6 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { val row = data.map(_.eval(null)) val dataTypes = data.map(_.dataType) - import scala.collection.JavaConversions._ def toWritableInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe)) @@ -115,7 +113,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - fields.map(f => f.name), fields.map(f => toWritableInspector(f.dataType))) + java.util.Arrays.asList(fields.map(f => f.name) :_*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) :_*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 041a36f1295e..fa6905f31f81 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.sources.DDLParser import org.apache.spark.sql.test.ExamplePointUDT +import org.apache.spark.sql.types.StructType class HiveMetastoreCatalogSuite extends FunSuite { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index fb481edc853b..7cfb875e05db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -22,6 +22,7 @@ import java.io.File import com.google.common.io.Files import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types._ /* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index ec9ebb4a775a..8ff833e0d60d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -25,6 +25,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.sql._ import org.apache.spark.util.Utils +import org.apache.spark.sql.types._ /* Implicits */ import org.apache.spark.sql.hive.test.TestHive._ diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 25fdf5c5f3da..a5587460fd69 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -18,8 +18,12 @@ package org.apache.spark.sql.hive import java.net.URI -import java.util.{ArrayList => JArrayList} -import java.util.Properties +import java.util.{ArrayList => JArrayList, Properties} + +import scala.collection.JavaConversions._ +import scala.language.implicitConversions + +import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.`type`.HiveDecimal @@ -29,20 +33,16 @@ import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.decimal.Decimal case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { // for Serialization diff --git 
a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index e47002cb0b8c..a7121360dd35 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive import java.util.{ArrayList => JArrayList} import java.util.Properties +import scala.collection.JavaConversions._ +import scala.language.implicitConversions + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable @@ -37,12 +40,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.types.DecimalType -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import scala.collection.JavaConversions._ -import scala.language.implicitConversions +import org.apache.spark.Logging +import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.decimal.Decimal /** From d5eeb35167e1ab72fab7778757163ff0aacaef2c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Jan 2015 00:38:55 -0800 Subject: [PATCH 138/215] [SPARK-5167][SQL] Move Row into sql package and make it usable for Java. Mostly just moving stuff around. This should still be source compatible since we type aliased Row previously in org.apache.spark.sql.Row. Added the following APIs to Row: ```scala def getMap[K, V](i: Int): scala.collection.Map[K, V] def getJavaMap[K, V](i: Int): java.util.Map[K, V] def getSeq[T](i: Int): Seq[T] def getList[T](i: Int): java.util.List[T] def getStruct(i: Int): StructType ``` Author: Reynold Xin Closes #4030 from rxin/sql-row and squashes the following commits: 6c85c29 [Reynold Xin] Fixed style violation by adding a new line to Row.scala. 82b064a [Reynold Xin] [SPARK-5167][SQL] Move Row into sql package and make it usable for Java. --- .../java/org/apache/spark/sql/RowFactory.java | 34 +++ .../main/scala/org/apache/spark/sql/Row.scala | 240 ++++++++++++++++++ .../sql/catalyst/expressions/package.scala | 4 + .../expressions/{Row.scala => rows.scala} | 113 ++------- .../apache/spark/sql/types/dataTypes.scala | 2 +- .../org/apache/spark/sql/api/java/Row.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 83 ------ 7 files changed, 304 insertions(+), 174 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{Row.scala => rows.scala} (68%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java new file mode 100644 index 000000000000..62fcec824d09 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import org.apache.spark.sql.catalyst.expressions.GenericRow; + +/** + * A factory class used to construct {@link Row} objects. + */ +public class RowFactory { + + /** + * Create a {@link Row} from an array of values. Position i in the array becomes position i + * in the created {@link Row} object. + */ + public static Row create(Object[] values) { + return new GenericRow(values); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala new file mode 100644 index 000000000000..d7a4e014ce6a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.GenericRow + + +object Row { + /** + * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + */ + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) +} + + +/** + * Represents one row of output from a relational operator. Allows both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * + * It is invalid to use the native primitive interface to retrieve a value that is null, instead a + * user must check `isNullAt` before attempting to retrieve a value that might be null. + * + * To create a new Row, use [[RowFactory.create()]] in Java or [[Row.apply()]] in Scala. + * + * A [[Row]] object can be constructed by providing field values. Example: + * {{{ + * import org.apache.spark.sql._ + * + * // Create a Row from values. 
+ * Row(value1, value2, value3, ...) + * // Create a Row from a Seq of values. + * Row.fromSeq(Seq(value1, value2, ...)) + * }}} + * + * A value of a row can be accessed through both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * An example of generic access by ordinal: + * {{{ + * import org.apache.spark.sql._ + * + * val row = Row(1, true, "a string", null) + * // row: Row = [1,true,a string,null] + * val firstValue = row(0) + * // firstValue: Any = 1 + * val fourthValue = row(3) + * // fourthValue: Any = null + * }}} + * + * For native primitive access, it is invalid to use the native primitive interface to retrieve + * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a + * value that might be null. + * An example of native primitive access: + * {{{ + * // using the row from the previous example. + * val firstValue = row.getInt(0) + * // firstValue: Int = 1 + * val isNull = row.isNullAt(3) + * // isNull: Boolean = true + * }}} + * + * Interfaces related to native primitive access are: + * + * `isNullAt(i: Int): Boolean` + * + * `getInt(i: Int): Int` + * + * `getLong(i: Int): Long` + * + * `getDouble(i: Int): Double` + * + * `getFloat(i: Int): Float` + * + * `getBoolean(i: Int): Boolean` + * + * `getShort(i: Int): Short` + * + * `getByte(i: Int): Byte` + * + * `getString(i: Int): String` + * + * In Scala, fields in a [[Row]] object can be extracted in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + * + * @group row + */ +trait Row extends Seq[Any] with Serializable { + def apply(i: Int): Any + + /** Returns the value at position i. If the value is null, null is returned. */ + def get(i: Int): Any = apply(i) + + /** Checks whether the value at position i is null. */ + def isNullAt(i: Int): Boolean + + /** + * Returns the value at position i as a primitive int. + * Throws an exception if the type mismatches or if the value is null. + */ + def getInt(i: Int): Int + + /** + * Returns the value at position i as a primitive long. + * Throws an exception if the type mismatches or if the value is null. + */ + def getLong(i: Int): Long + + /** + * Returns the value at position i as a primitive double. + * Throws an exception if the type mismatches or if the value is null. + */ + def getDouble(i: Int): Double + + /** + * Returns the value at position i as a primitive float. + * Throws an exception if the type mismatches or if the value is null. + */ + def getFloat(i: Int): Float + + /** + * Returns the value at position i as a primitive boolean. + * Throws an exception if the type mismatches or if the value is null. + */ + def getBoolean(i: Int): Boolean + + /** + * Returns the value at position i as a primitive short. + * Throws an exception if the type mismatches or if the value is null. + */ + def getShort(i: Int): Short + + /** + * Returns the value at position i as a primitive byte. + * Throws an exception if the type mismatches or if the value is null. + */ + def getByte(i: Int): Byte + + /** + * Returns the value at position i as a String object. + * Throws an exception if the type mismatches or if the value is null. + */ + def getString(i: Int): String + + /** + * Return the value at position i of array type as a Scala Seq. + * Throws an exception if the type mismatches. 
+ */ + def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]] + + /** + * Return the value at position i of array type as [[java.util.List]]. + * Throws an exception if the type mismatches. + */ + def getList[T](i: Int): java.util.List[T] = { + scala.collection.JavaConversions.seqAsJavaList(getSeq[T](i)) + } + + /** + * Return the value at position i of map type as a Scala Map. + * Throws an exception if the type mismatches. + */ + def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]] + + /** + * Return the value at position i of array type as a [[java.util.Map]]. + * Throws an exception if the type mismatches. + */ + def getJavaMap[K, V](i: Int): java.util.Map[K, V] = { + scala.collection.JavaConversions.mapAsJavaMap(getMap[K, V](i)) + } + + /** + * Return the value at position i of struct type as an [[Row]] object. + * Throws an exception if the type mismatches. + */ + def getStruct(i: Int): Row = getAs[Row](i) + + /** + * Returns the value at position i. + * Throws an exception if the type mismatches. + */ + def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + + override def toString(): String = s"[${this.mkString(",")}]" + + /** + * Make a copy of the current [[Row]] object. + */ + def copy(): Row + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + val l = length + var i = 0 + while (i < l) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 55d95991c5f1..fbc97b2e7531 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -49,6 +49,10 @@ package org.apache.spark.sql.catalyst */ package object expressions { + type Row = org.apache.spark.sql.Row + + val Row = org.apache.spark.sql.Row + /** * Converts a [[Row]] to another Row given a sequence of expression that define each column of the * new row. If the schema of the input row is specified, then the given expression will be bound diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala similarity index 68% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index dcda53bb717a..c22b8426841d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,68 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types.NativeType -object Row { - /** - * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: - * {{{ - * import org.apache.spark.sql._ - * - * val pairs = sql("SELECT key, value FROM src").rdd.map { - * case Row(key: Int, value: String) => - * key -> value - * } - * }}} - */ - def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) - - /** - * This method can be used to construct a [[Row]] with the given values. - */ - def apply(values: Any*): Row = new GenericRow(values.toArray) - - /** - * This method can be used to construct a [[Row]] from a [[Seq]] of values. 
- */ - def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) -} - -/** - * Represents one row of output from a relational operator. Allows both generic access by ordinal, - * which will incur boxing overhead for primitives, as well as native primitive access. - * - * It is invalid to use the native primitive interface to retrieve a value that is null, instead a - * user must check [[isNullAt]] before attempting to retrieve a value that might be null. - */ -trait Row extends Seq[Any] with Serializable { - def apply(i: Int): Any - - def isNullAt(i: Int): Boolean - - def getInt(i: Int): Int - def getLong(i: Int): Long - def getDouble(i: Int): Double - def getFloat(i: Int): Float - def getBoolean(i: Int): Boolean - def getShort(i: Int): Short - def getByte(i: Int): Byte - def getString(i: Int): String - def getAs[T](i: Int): T = apply(i).asInstanceOf[T] - - override def toString() = - s"[${this.mkString(",")}]" - - def copy(): Row - - /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean = { - var i = 0 - while (i < length) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } -} /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting @@ -105,22 +43,19 @@ trait MutableRow extends Row { * A row with no data. Calling any methods will result in an error. Can be used as a placeholder. */ object EmptyRow extends Row { - def apply(i: Int): Any = throw new UnsupportedOperationException - - def iterator = Iterator.empty - def length = 0 - def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException - - def getInt(i: Int): Int = throw new UnsupportedOperationException - def getLong(i: Int): Long = throw new UnsupportedOperationException - def getDouble(i: Int): Double = throw new UnsupportedOperationException - def getFloat(i: Int): Float = throw new UnsupportedOperationException - def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException - def getShort(i: Int): Short = throw new UnsupportedOperationException - def getByte(i: Int): Byte = throw new UnsupportedOperationException - def getString(i: Int): String = throw new UnsupportedOperationException + override def apply(i: Int): Any = throw new UnsupportedOperationException + override def iterator = Iterator.empty + override def length = 0 + override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException + override def getInt(i: Int): Int = throw new UnsupportedOperationException + override def getLong(i: Int): Long = throw new UnsupportedOperationException + override def getDouble(i: Int): Double = throw new UnsupportedOperationException + override def getFloat(i: Int): Float = throw new UnsupportedOperationException + override def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException + override def getShort(i: Int): Short = throw new UnsupportedOperationException + override def getByte(i: Int): Byte = throw new UnsupportedOperationException + override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - def copy() = this } @@ -135,50 +70,50 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { def this(size: Int) = this(new Array[Any](size)) - def iterator = values.iterator + override def iterator = values.iterator - def length = values.length + override def length = values.length - def apply(i: Int) = values(i) + override def apply(i: Int) = values(i) - def isNullAt(i: Int) = 
values(i) == null + override def isNullAt(i: Int) = values(i) == null - def getInt(i: Int): Int = { + override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") values(i).asInstanceOf[Int] } - def getLong(i: Int): Long = { + override def getLong(i: Int): Long = { if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") values(i).asInstanceOf[Long] } - def getDouble(i: Int): Double = { + override def getDouble(i: Int): Double = { if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") values(i).asInstanceOf[Double] } - def getFloat(i: Int): Float = { + override def getFloat(i: Int): Float = { if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") values(i).asInstanceOf[Float] } - def getBoolean(i: Int): Boolean = { + override def getBoolean(i: Int): Boolean = { if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") values(i).asInstanceOf[Boolean] } - def getShort(i: Int): Short = { + override def getShort(i: Int): Short = { if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") values(i).asInstanceOf[Short] } - def getByte(i: Int): Byte = { + override def getByte(i: Int): Byte = { if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") values(i).asInstanceOf[Byte] } - def getString(i: Int): String = { + override def getString(i: Int): String = { if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") values(i).asInstanceOf[String] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index fa0a355ebc9b..e38ad63f2e2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -727,7 +727,7 @@ object StructType { * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) * }}} * - * A [[org.apache.spark.sql.catalyst.expressions.Row]] object is used as a value of the StructType. + * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. * Example: * {{{ * import org.apache.spark.sql._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 207e2805fffe..4faa79af2568 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions import scala.math.BigDecimal import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap -import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} +import org.apache.spark.sql.{Row => ScalaRow} /** * A result row from a Spark SQL query. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index b75266d5aa40..6dd39be80703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -33,89 +33,6 @@ import org.apache.spark.sql.execution.SparkPlan */ package object sql { - /** - * :: DeveloperApi :: - * - * Represents one row of output from a relational operator. 
- * @group row - */ - @DeveloperApi - type Row = catalyst.expressions.Row - - /** - * :: DeveloperApi :: - * - * A [[Row]] object can be constructed by providing field values. Example: - * {{{ - * import org.apache.spark.sql._ - * - * // Create a Row from values. - * Row(value1, value2, value3, ...) - * // Create a Row from a Seq of values. - * Row.fromSeq(Seq(value1, value2, ...)) - * }}} - * - * A value of a row can be accessed through both generic access by ordinal, - * which will incur boxing overhead for primitives, as well as native primitive access. - * An example of generic access by ordinal: - * {{{ - * import org.apache.spark.sql._ - * - * val row = Row(1, true, "a string", null) - * // row: Row = [1,true,a string,null] - * val firstValue = row(0) - * // firstValue: Any = 1 - * val fourthValue = row(3) - * // fourthValue: Any = null - * }}} - * - * For native primitive access, it is invalid to use the native primitive interface to retrieve - * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a - * value that might be null. - * An example of native primitive access: - * {{{ - * // using the row from the previous example. - * val firstValue = row.getInt(0) - * // firstValue: Int = 1 - * val isNull = row.isNullAt(3) - * // isNull: Boolean = true - * }}} - * - * Interfaces related to native primitive access are: - * - * `isNullAt(i: Int): Boolean` - * - * `getInt(i: Int): Int` - * - * `getLong(i: Int): Long` - * - * `getDouble(i: Int): Double` - * - * `getFloat(i: Int): Float` - * - * `getBoolean(i: Int): Boolean` - * - * `getShort(i: Int): Short` - * - * `getByte(i: Int): Byte` - * - * `getString(i: Int): String` - * - * Fields in a [[Row]] object can be extracted in a pattern match. Example: - * {{{ - * import org.apache.spark.sql._ - * - * val pairs = sql("SELECT key, value FROM src").rdd.map { - * case Row(key: Int, value: String) => - * key -> value - * } - * }}} - * - * @group row - */ - @DeveloperApi - val Row = catalyst.expressions.Row - /** * Converts a logical plan into zero or more SparkPlans. 
*/ From a3f7421b42f45e39f3e53679188e4eae2ed1f208 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 14 Jan 2015 09:36:59 -0800 Subject: [PATCH 139/215] [SPARK-5248] [SQL] move sql.types.decimal.Decimal to sql.types.Decimal rxin follow up of #3732 Author: Daoyuan Wang Closes #4041 from adrian-wang/decimal and squashes the following commits: aa3d738 [Daoyuan Wang] fix auto refactor 7777a58 [Daoyuan Wang] move sql.types.decimal.Decimal to sql.types.Decimal --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 1 - .../scala/org/apache/spark/sql/catalyst/dsl/package.scala | 1 - .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 1 - .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 7 +++---- .../spark/sql/catalyst/expressions/decimalFunctions.scala | 3 +-- .../apache/spark/sql/catalyst/expressions/literals.scala | 1 - .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 - .../org/apache/spark/sql/types/DataTypeConversions.scala | 2 +- .../org/apache/spark/sql/types/{decimal => }/Decimal.scala | 4 ++-- .../main/scala/org/apache/spark/sql/types/dataTypes.scala | 1 - .../catalyst/expressions/ExpressionEvaluationSuite.scala | 1 - .../org/apache/spark/sql/types/decimal/DecimalSuite.scala | 1 + .../apache/spark/sql/execution/SparkSqlSerializer.scala | 3 ++- .../src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 1 - .../org/apache/spark/sql/parquet/ParquetConverter.scala | 1 - .../org/apache/spark/sql/parquet/ParquetTableSupport.scala | 1 - .../test/scala/org/apache/spark/sql/json/JsonSuite.scala | 1 - .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 3 +-- .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 1 - .../org/apache/spark/sql/hive/HiveInspectorSuite.scala | 1 - .../src/main/scala/org/apache/spark/sql/hive/Shim12.scala | 3 +-- .../src/main/scala/org/apache/spark/sql/hive/Shim13.scala | 3 +-- 22 files changed, 13 insertions(+), 29 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/types/{decimal => }/Decimal.scala (98%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d169da691d79..697bacfedc62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -23,7 +23,6 @@ import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index bdac7504ed02..8bc36a238dbb 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /** * A collection of implicit conversions that create a DSL for constructing catalyst data structures. 
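The `Row` trait and `RowFactory` introduced by the [SPARK-5167] patch above expose both generic ordinal access and the typed getters listed in that commit message. A minimal usage sketch assembled from the examples in that patch's own scaladoc; the values are illustrative:

```scala
import org.apache.spark.sql.Row

val row = Row(1, true, "a string", null)

val first: Any = row(0)               // generic access by ordinal (boxes primitives)
val flag: Boolean = row.getBoolean(1) // typed primitive access
val isNull = row.isNullAt(3)          // check before calling a primitive getter

// Pattern matching through Row.unapplySeq.
row match {
  case Row(key: Int, _, value: String, _) => println(s"$key -> $value")
}

// Collection accessors added in this patch.
val nested = Row(Seq(1, 2, 3), Map("a" -> 1))
val xs: Seq[Int] = nested.getSeq[Int](0)
val m: scala.collection.Map[String, Int] = nested.getMap[String, Int](1)
```

From Java, the equivalent construction goes through the new `RowFactory.create`, which takes an `Object[]` of values as shown in the added `RowFactory` class.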
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 00961f09916b..1a2133bbbcec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,7 +23,6 @@ import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a5d642339129..4cae5c471868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import com.google.common.cache.{CacheLoader, CacheBuilder} -import org.apache.spark.sql.types.decimal.Decimal import scala.language.existentials @@ -541,11 +540,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin childEval.code ++ q""" var $nullTerm = ${childEval.nullTerm} - var $primitiveTerm: org.apache.spark.sql.types.decimal.Decimal = + var $primitiveTerm: org.apache.spark.sql.types.Decimal = ${defaultPrimitive(DecimalType())} if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.decimal.Decimal() + $primitiveTerm = new org.apache.spark.sql.types.Decimal() $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) $nullTerm = $primitiveTerm == null } @@ -627,7 +626,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case LongType => ru.Literal(Constant(1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType() => q"org.apache.spark.sql.types.decimal.Decimal(-1)" + case DecimalType() => q"org.apache.spark.sql.types.Decimal(-1)" case IntegerType => ru.Literal(Constant(-1)) case _ => ru.Literal(Constant(null)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index e54cfa144a17..83d8c1d42bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.decimal.Decimal -import org.apache.spark.sql.types.{DecimalType, LongType, DoubleType, DataType} +import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ case class UnscaledValue(child: Expression) extends UnaryExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 
8ee4bbd8caa6..c94a947fb275 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal object Literal { def apply(v: Any): Literal = v match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 17b4f9c23a97..d4a4c35691bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal abstract class Optimizer extends RuleExecutor[LogicalPlan] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala index 2a8914cde248..08bb933a2b33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.decimal.Decimal +import org.apache.spark.sql.types.Decimal protected[sql] object DataTypeConversions { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c7864d1ae9e7..3744d77c0736 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/decimal/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.types.decimal +package org.apache.spark.sql.types import org.apache.spark.annotation.DeveloperApi @@ -28,7 +28,7 @@ import org.apache.spark.annotation.DeveloperApi * - Otherwise, the decimal value is longVal / (10 ** _scale) */ final class Decimal extends Ordered[Decimal] with Serializable { - import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO} + import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE} private var decimalVal: BigDecimal = null private var longVal: Long = 0L diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index e38ad63f2e2c..e1cbe6650aaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -32,7 +32,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.sql.types.decimal._ import org.apache.spark.util.Utils diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 8552448b8d10..37e64adeea85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal class ExpressionEvaluationSuite extends FunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 813377df0013..de6a2cd448c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.types.decimal +import org.apache.spark.sql.types.Decimal import org.scalatest.{PrivateMethodTester, FunSuite} import scala.language.postfixOps diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 131146012eca..7a0249137a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import org.apache.spark.sql.types.Decimal + import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -29,7 +31,6 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.decimal.Decimal import org.apache.spark.util.collection.OpenHashSet import 
org.apache.spark.util.MutablePair import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index c92ec543e293..453b560ff871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 7f437c40777f..b4aed0419912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -25,7 +25,6 @@ import parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /** * Collection of converters of Parquet types (group and primitive types) that diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index af7248fdf451..fd63ad814406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -30,7 +30,6 @@ import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /** * A `parquet.io.api.RecordMaterializer` for Rows. 
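The [SPARK-5248] hunks above are purely mechanical: `Decimal` now lives directly in `org.apache.spark.sql.types`, so only import paths change. A minimal sketch of the relocated class in use, with illustrative values; `Decimal(-1)` mirrors the factory call emitted by the code generator earlier in this patch, while `DecimalType.Unlimited` is assumed from the surrounding codebase rather than shown in these hunks:

```scala
import org.apache.spark.sql.types.{Decimal, DecimalType}

// Decimal is now a top-level member of org.apache.spark.sql.types.
val d = Decimal(-1)                          // same factory call the generated code emits
val dt: DecimalType = DecimalType.Unlimited  // assumed: the unconstrained decimal type
println(s"${d.toDouble} : $dt")
```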
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 01c1ce2a6102..1dd85a3bb43a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal import org.apache.spark.sql.{QueryTest, Row, SQLConf} class JsonSuite extends QueryTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 245b847cf4cd..5140d2064c5f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -43,7 +42,7 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.types.decimal.Decimal + * org.apache.spark.sql.types.Decimal * Array[Byte] * java.sql.Date * java.sql.Timestamp diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b13ef7276bf3..5e29e57d9358 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index dc23d9a101d1..486460725203 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -31,7 +31,6 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{Literal, Row} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.decimal.Decimal class HiveInspectorSuite extends FunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index a5587460fd69..58417a15bbed 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -41,8 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.DecimalType 
-import org.apache.spark.sql.types.decimal.Decimal +import org.apache.spark.sql.types.{Decimal, DecimalType} case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { // for Serialization diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index a7121360dd35..1f768ca97124 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -42,8 +42,7 @@ import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.types.decimal.Decimal +import org.apache.spark.sql.types.{Decimal, DecimalType} /** From 81f72a0df2250debe8a6a0773d809d8c42eeabb9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 14 Jan 2015 09:47:30 -0800 Subject: [PATCH 140/215] [SPARK-5211][SQL]Restore HiveMetastoreTypes.toDataType jira: https://issues.apache.org/jira/browse/SPARK-5211 Author: Yin Huai Closes #4026 from yhuai/SPARK-5211 and squashes the following commits: 15ee32b [Yin Huai] Remove extra line. c6c1651 [Yin Huai] Get back HiveMetastoreTypes.toDataType. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 8 +++++++- .../apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 5 +---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d40f9936fd3b..1a49f09bd998 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -558,6 +558,12 @@ private[hive] case class MetastoreRelation } object HiveMetastoreTypes { + protected val ddlParser = new DDLParser + + def toDataType(metastoreType: String): DataType = synchronized { + ddlParser.parseType(metastoreType) + } + def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index fa6905f31f81..aad48ada5264 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive import org.scalatest.FunSuite -import org.apache.spark.sql.sources.DDLParser import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType @@ -28,9 +27,7 @@ class HiveMetastoreCatalogSuite extends FunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct" - val ddlParser = new DDLParser - - 
val datatype = ddlParser.parseType(metastr) + val datatype = HiveMetastoreTypes.toDataType(metastr) assert(datatype.isInstanceOf[StructType]) } From 38bdc992a1a0485ac630af500da54f0a77e133bf Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 14 Jan 2015 09:50:01 -0800 Subject: [PATCH 141/215] [SQL] some comments fix for GROUPING SETS Author: Daoyuan Wang Closes #4000 from adrian-wang/comment and squashes the following commits: 9c24fc4 [Daoyuan Wang] some comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bd00ff22ba80..7f4cc234dc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -126,10 +126,10 @@ class Analyzer(catalog: Catalog, } /* - * GROUP BY a, b, c, WITH ROLLUP + * GROUP BY a, b, c WITH ROLLUP * is equivalent to - * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( )). - * Group Count: N + 1 (N is the number of group expression) + * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( ) ). + * Group Count: N + 1 (N is the number of group expressions) * * We need to get all of its subsets for the rule described above, the subset is * represented as the bit masks. @@ -139,12 +139,12 @@ class Analyzer(catalog: Catalog, } /* - * GROUP BY a, b, c, WITH CUBE + * GROUP BY a, b, c WITH CUBE * is equivalent to * GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ). - * Group Count: 2^N (N is the number of group expression) + * Group Count: 2 ^ N (N is the number of group expressions) * - * We need to get all of its sub sets for a given GROUPBY expressions, the subset is + * We need to get all of its subsets for a given GROUPBY expression, the subsets are * represented as the bit masks. */ def bitmasks(c: Cube): Seq[Int] = { From 5840f5464bad8431810d459c97d6e4635eea175c Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 14 Jan 2015 11:03:11 -0800 Subject: [PATCH 142/215] [SPARK-2909] [MLlib] [PySpark] SparseVector in pyspark now supports indexing Slightly different than the scala code which converts the sparsevector into a densevector and then checks the index. I also hope I've added tests in the right place. Author: MechCoder Closes #4025 from MechCoder/spark-2909 and squashes the following commits: 07d0f26 [MechCoder] STY: Rename item to index f02148b [MechCoder] [SPARK-2909] [Mlib] SparseVector in pyspark now supports indexing --- python/pyspark/mllib/linalg.py | 17 +++++++++++++++++ python/pyspark/mllib/tests.py | 12 ++++++++++++ 2 files changed, 29 insertions(+) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 4f8491f43e45..7f21190ed8c2 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -510,6 +510,23 @@ def __eq__(self, other): and np.array_equal(other.indices, self.indices) and np.array_equal(other.values, self.values)) + def __getitem__(self, index): + inds = self.indices + vals = self.values + if not isinstance(index, int): + raise ValueError( + "Indices must be of type integer, got type %s" % type(index)) + if index < 0: + index += self.size + if index >= self.size or index < 0: + raise ValueError("Index %d out of bounds." 
% index) + + insert_index = np.searchsorted(inds, index) + row_ind = inds[insert_index] + if row_ind == index: + return vals[insert_index] + return 0. + def __ne__(self, other): return not self.__eq__(other) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1f48bc1219db..140c22b5fd4e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -120,6 +120,18 @@ def test_conversion(self): dv = DenseVector(v) self.assertTrue(dv.array.dtype == 'float64') + def test_sparse_vector_indexing(self): + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEquals(sv[0], 0.) + self.assertEquals(sv[3], 2.) + self.assertEquals(sv[1], 1.) + self.assertEquals(sv[2], 0.) + self.assertEquals(sv[-1], 2) + self.assertEquals(sv[-2], 0) + self.assertEquals(sv[-4], 0) + for ind in [4, -5, 7.8]: + self.assertRaises(ValueError, sv.__getitem__, ind) + class ListTests(PySparkTestCase): From 9d4449c4b3d0f5e60e71fef3a9b2e1c58a80b9e8 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 14 Jan 2015 11:10:29 -0800 Subject: [PATCH 143/215] [SPARK-5228][WebUI] Hide tables for "Active Jobs/Completed Jobs/Failed Jobs" when they are empty In current WebUI, tables for Active Stages, Completed Stages, Skipped Stages and Failed Stages are hidden when they are empty while tables for Active Jobs, Completed Jobs and Failed Jobs are not hidden though they are empty. This is before my patch is applied. ![2015-01-13 14 13 03](https://cloud.githubusercontent.com/assets/4736016/5730793/2b73d6f4-9b32-11e4-9a24-1784d758c644.png) And this is after my patch is applied. ![2015-01-13 14 38 13](https://cloud.githubusercontent.com/assets/4736016/5730797/359ea2da-9b32-11e4-97b0-544739ddbf4c.png) Author: Kousuke Saruta Closes #4028 from sarutak/SPARK-5228 and squashes the following commits: b1e6e8b [Kousuke Saruta] Fixed a small typo daab563 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-5228 9493a1d [Kousuke Saruta] Modified AllJobPage.scala so that hide Active Jobs/Completed Jobs/Failed Jobs when they are empty --- .../apache/spark/ui/jobs/AllJobsPage.scala | 62 +++++++++++++------ 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index ea2d187a0e8e..1d1c70187844 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -47,7 +47,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val lastStageData = lastStageInfo.flatMap { s => listener.stageIdToData.get((s.stageId, s.attemptId)) } - val isComplete = job.status == JobExecutionStatus.SUCCEEDED + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") val duration: Option[Long] = { @@ -107,6 +107,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val failedJobsTable = jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + val shouldShowActiveJobs = activeJobs.nonEmpty + val shouldShowCompletedJobs = completedJobs.nonEmpty + val shouldShowFailedJobs = failedJobs.nonEmpty + val summary: NodeSeq =
          @@ -121,27 +125,47 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { <strong>Scheduling Mode: </strong> {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} </li> - <li> - <a href="#active"><strong>Active Jobs:</strong></a> - {activeJobs.size} - </li>
        - <li> - <a href="#completed"><strong>Completed Jobs:</strong></a> - {completedJobs.size} - </li> - <li> - <a href="#failed"><strong>Failed Jobs:</strong></a> - {failedJobs.size} - </li>
        + { + if (shouldShowActiveJobs) { + <li> + <a href="#active"><strong>Active Jobs:</strong></a> + {activeJobs.size} + </li> + } + }
        + { + if (shouldShowCompletedJobs) { + <li> + <a href="#completed"><strong>Completed Jobs:</strong></a> + {completedJobs.size} + </li> + } + }
        + { + if (shouldShowFailedJobs) { + <li> + <a href="#failed"><strong>Failed Jobs:</strong></a> + {failedJobs.size} + </li> + } + } </ul> </div>
        - val content = summary ++ - <h4 id="active">Active Jobs ({activeJobs.size})</h4> ++ activeJobsTable ++ - <h4 id="completed">Completed Jobs ({completedJobs.size})</h4> ++ completedJobsTable ++ - <h4 id="failed">Failed Jobs ({failedJobs.size})</h4> ++ failedJobsTable
        - - val helpText = """A job is triggered by a action, like "count()" or "saveAsTextFile()".""" + + var content = summary + if (shouldShowActiveJobs) { + content ++= <h4 id="active">Active Jobs ({activeJobs.size})</h4> ++ + activeJobsTable + }
        + if (shouldShowCompletedJobs) { + content ++= <h4 id="completed">Completed Jobs ({completedJobs.size})</h4> ++ + completedJobsTable + }
        + if (shouldShowFailedJobs) { + content ++= <h4 id="failed">Failed Jobs ({failedJobs.size})</h4>
        ++ + failedJobsTable + } + val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" + " Click on a job's title to see information about the stages of tasks associated with" + " the job." From 259936be710f367e1d1e019ee7b93fb68dfc33d0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 14 Jan 2015 11:45:40 -0800 Subject: [PATCH 144/215] [SPARK-4014] Add TaskContext.attemptNumber and deprecate TaskContext.attemptId `TaskContext.attemptId` is misleadingly-named, since it currently returns a taskId, which uniquely identifies a particular task attempt within a particular SparkContext, instead of an attempt number, which conveys how many times a task has been attempted. This patch deprecates `TaskContext.attemptId` and add `TaskContext.taskId` and `TaskContext.attemptNumber` fields. Prior to this change, it was impossible to determine whether a task was being re-attempted (or was a speculative copy), which made it difficult to write unit tests for tasks that fail on early attempts or speculative tasks that complete faster than original tasks. Earlier versions of the TaskContext docs suggest that `attemptId` behaves like `attemptNumber`, so there's an argument to be made in favor of changing this method's implementation. Since we've decided against making that change in maintenance branches, I think it's simpler to add better-named methods and retain the old behavior for `attemptId`; if `attemptId` behaved differently in different branches, then this would cause confusing build-breaks when backporting regression tests that rely on the new `attemptId` behavior. Most of this patch is fairly straightforward, but there is a bit of trickiness related to Mesos tasks: since there's no field in MesosTaskInfo to encode the attemptId, I packed it into the `data` field alongside the task binary. Author: Josh Rosen Closes #3849 from JoshRosen/SPARK-4014 and squashes the following commits: 89d03e0 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014 5cfff05 [Josh Rosen] Introduce wrapper for serializing Mesos task launch data. 38574d4 [Josh Rosen] attemptId -> taskAttemptId in PairRDDFunctions a180b88 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014 1d43aa6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014 eee6a45 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014 0b10526 [Josh Rosen] Use putInt instead of putLong (silly mistake) 8c387ce [Josh Rosen] Use local with maxRetries instead of local-cluster. cbe4d76 [Josh Rosen] Preserve attemptId behavior and deprecate it: b2dffa3 [Josh Rosen] Address some of Reynold's minor comments 9d8d4d1 [Josh Rosen] Doc typo 1e7a933 [Josh Rosen] [SPARK-4014] Change TaskContext.attemptId to return attempt number instead of task ID. 
fd515a5 [Josh Rosen] Add failing test for SPARK-4014 --- .../java/org/apache/spark/TaskContext.java | 24 +++++++++- .../org/apache/spark/TaskContextImpl.scala | 9 +++- .../CoarseGrainedExecutorBackend.scala | 3 +- .../org/apache/spark/executor/Executor.scala | 17 +++++-- .../spark/executor/MesosExecutorBackend.scala | 5 +- .../org/apache/spark/rdd/CheckpointRDD.scala | 5 +- .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 9 ++-- .../apache/spark/scheduler/DAGScheduler.scala | 4 +- .../org/apache/spark/scheduler/Task.scala | 12 ++++- .../spark/scheduler/TaskDescription.scala | 1 + .../spark/scheduler/TaskSetManager.scala | 3 +- .../cluster/mesos/MesosSchedulerBackend.scala | 4 +- .../cluster/mesos/MesosTaskLaunchData.scala | 46 +++++++++++++++++++ .../spark/scheduler/local/LocalBackend.scala | 3 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 8 ++-- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 31 ++++++++++++- .../mesos/MesosSchedulerBackendSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 6 +-- .../org/apache/spark/ui/UISeleniumSuite.scala | 2 +- project/MimaExcludes.scala | 6 +++ .../sql/parquet/ParquetTableOperations.scala | 5 +- .../hive/execution/InsertIntoHiveTable.scala | 5 +- 25 files changed, 168 insertions(+), 48 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 0d6973203eba..095f9fb94fdf 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -62,7 +62,7 @@ static void unset() { */ public abstract boolean isInterrupted(); - /** @deprecated: use isRunningLocally() */ + /** @deprecated use {@link #isRunningLocally()} */ @Deprecated public abstract boolean runningLocally(); @@ -87,19 +87,39 @@ static void unset() { * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * - * @deprecated: use addTaskCompletionListener + * @deprecated use {@link #addTaskCompletionListener(scala.Function1)} * * @param f Callback function. */ @Deprecated public abstract void addOnCompleteCallback(final Function0 f); + /** + * The ID of the stage that this task belong to. + */ public abstract int stageId(); + /** + * The ID of the RDD partition that is computed by this task. + */ public abstract int partitionId(); + /** + * How many times this task has been attempted. The first task attempt will be assigned + * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. + */ + public abstract int attemptNumber(); + + /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */ + @Deprecated public abstract long attemptId(); + /** + * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts + * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. 
+ */ + public abstract long taskAttemptId(); + /** ::DeveloperApi:: */ @DeveloperApi public abstract TaskMetrics taskMetrics(); diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index afd2b85d33a7..9bb0c61e441f 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce import scala.collection.mutable.ArrayBuffer -private[spark] class TaskContextImpl(val stageId: Int, +private[spark] class TaskContextImpl( + val stageId: Int, val partitionId: Int, - val attemptId: Long, + override val taskAttemptId: Long, + override val attemptNumber: Int, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext with Logging { + // For backwards-compatibility; this method is now deprecated as of 1.3.0. + override def attemptId: Long = taskAttemptId + // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c794a7bc3599..9a4adfbbb3d7 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -71,7 +71,8 @@ private[spark] class CoarseGrainedExecutorBackend( val ser = env.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask) + executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, + taskDesc.name, taskDesc.serializedTask) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 0f99cd9f3b08..b75c77b5b445 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -108,8 +108,13 @@ private[spark] class Executor( startDriverHeartbeater() def launchTask( - context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { - val tr = new TaskRunner(context, taskId, taskName, serializedTask) + context: ExecutorBackend, + taskId: Long, + attemptNumber: Int, + taskName: String, + serializedTask: ByteBuffer) { + val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, + serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } @@ -134,7 +139,11 @@ private[spark] class Executor( private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum class TaskRunner( - execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) + execBackend: ExecutorBackend, + val taskId: Long, + val attemptNumber: Int, + taskName: String, + serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false @@ -180,7 +189,7 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() - val value = task.run(taskId.toInt) + val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 2e23ae0a4f83..cfd672e1d8a9 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} import org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -77,11 +78,13 @@ private[spark] class MesosExecutorBackend override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { val taskId = taskInfo.getTaskId.getValue.toLong + val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData) if (executor == null) { logError("Received launchTask but executor was null") } else { SparkHadoopUtil.get.runAsSparkUser { () => - executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer) + executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber, + taskInfo.getName, taskData.serializedTask) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 7ba1182f0ed2..1c13e2c37284 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging { val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) + val tempOutputPath = + new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptNumber) if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + @@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging { logInfo("Deleting tempOutputPath " + tempOutputPath) fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptId + " and final output path does not exist") + + ctx.attemptNumber + " and final output path does not exist") } else { // Some other copy of this task must've finished before us and renamed it logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 0001c2329c83..37e0c13029d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -229,7 +229,7 @@ class HadoopRDD[K, V]( var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 38f8f36a4a4d..e43e5066655b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -978,12 +978,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { val config = wrappedConf.value - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - attemptNumber) + context.attemptNumber) val hadoopContext = newTaskAttemptContext(config, attemptId) val format = outfmt.newInstance format match { @@ -1062,11 +1059,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) - writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() try { var recordsWritten = 0L diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 61d09d73e17c..8cb15918baa8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -635,8 +635,8 @@ class DAGScheduler( try { val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) - val taskContext = - new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, + attemptNumber = 0, runningLocally = true) TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d7dde4fe3843..2367f7e2cf67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -44,8 +44,16 @@ import org.apache.spark.util.Utils */ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { - final def run(attemptId: Long): T = { - context = new TaskContextImpl(stageId, partitionId, attemptId, runningLocally = false) + /** + * Called by Executor to run this task. + * + * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. 
+ * @param attemptNumber how many times this task has been attempted (0 for the first attempt) + * @return the result of the task + */ + final def run(taskAttemptId: Long, attemptNumber: Int): T = { + context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, + taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 4c96b9e5fef6..1c7c81c488c3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer */ private[spark] class TaskDescription( val taskId: Long, + val attemptNumber: Int, val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 466785091715..5c94c6bbcb37 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,7 +487,8 @@ private[spark] class TaskSetManager( taskName, taskId, host, taskLocality, serializedTask.limit)) sched.dagScheduler.taskStarted(task, info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) + return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, + taskName, index, serializedTask)) } case _ => } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 10e6886c16a4..75d8ddf375e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -22,7 +22,7 @@ import java.util.{ArrayList => JArrayList, List => JList} import java.util.Collections import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler} @@ -296,7 +296,7 @@ private[spark] class MesosSchedulerBackend( .setExecutor(createExecutorInfo(slaveId)) .setName(task.name) .addResources(cpuResource) - .setData(ByteString.copyFrom(task.serializedTask)) + .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala new file mode 100644 index 000000000000..4416ce92ade2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.nio.ByteBuffer + +import org.apache.mesos.protobuf.ByteString + +/** + * Wrapper for serializing the data sent when launching Mesos tasks. + */ +private[spark] case class MesosTaskLaunchData( + serializedTask: ByteBuffer, + attemptNumber: Int) { + + def toByteString: ByteString = { + val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit) + dataBuffer.putInt(attemptNumber) + dataBuffer.put(serializedTask) + ByteString.copyFrom(dataBuffer) + } +} + +private[spark] object MesosTaskLaunchData { + def fromByteString(byteString: ByteString): MesosTaskLaunchData = { + val byteBuffer = byteString.asReadOnlyByteBuffer() + val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes + val serializedTask = byteBuffer.slice() // subsequence starting at the current position + MesosTaskLaunchData(serializedTask, attemptNumber) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index b3bd3110ac80..05b6fa54564b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -76,7 +76,8 @@ private[spark] class LocalActor( val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK - executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask) + executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, + task.name, task.serializedTask) } } } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 5ce299d05824..07b1e44d04be 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -820,7 +820,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index c0735f448d19..d7d9dc7b50f3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar // in blockManager.put is a losing battle. You have been warned. 
blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 271a90c6646b..1a9a0e857e54 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 561a5e9cd90c..057e22691602 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -45,13 +45,13 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val task = new ResultTask[String, String]( 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { - task.run(0) + task.run(0, 0) } assert(TaskContextSuite.completed === true) } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) @@ -63,6 +63,33 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte verify(listener, times(1)).onTaskCompletion(any()) } + + test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") { + sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 
because we test failed tasks + // Check that attemptIds are 0 for all tasks' initial attempts + val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter => + Seq(TaskContext.get().attemptNumber).iterator + }.collect() + assert(attemptIds.toSet === Set(0)) + + // Test a job with failed tasks + val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter => + val attemptId = TaskContext.get().attemptNumber + if (iter.next() == 1 && attemptId == 0) { + throw new Exception("First execution of task failed") + } + Seq(attemptId).iterator + }.collect() + assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) + } + + test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") { + sc = new SparkContext("local", "test") + val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter => + Seq(TaskContext.get().attemptId).iterator + }.collect() + assert(attemptIds.toSet === Set(0, 1, 2, 3)) + } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index e60e70afd321..48f5e40f506d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -80,7 +80,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea mesosOffers.get(2).getHostname, 2 )) - val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() EasyMock.replay(taskScheduler) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 1eaabb93adbe..37b593b2c5f7 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0), + new TaskContextImpl(0, 0, 0, 0), transfer, blockManager, blocksByAddress, @@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 
787f4c2b5a8b..e85a436cdba1 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -173,7 +173,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { // Simulate fetch failures: val mappedData = data.map { x => val taskContext = TaskContext.get - if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt. + if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt. val env = SparkEnv.get val bmAddress = env.blockManager.blockManagerId val shuffleId = shuffleHandle.shuffleId diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f6f9f491f4ce..d3ea59424572 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -72,6 +72,12 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") + ) ++ Seq( + // SPARK-4014 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.taskAttemptId"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.attemptNumber") ) case v if v.startsWith("1.2") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index f5487740d3af..28cd17fde46a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -301,12 +301,9 @@ case class InsertIntoParquetTable( } def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - attemptNumber) + context.attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset) val committer = format.getOutputCommitter(hadoopContext) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index ca0ec1513917..42bc8a0b6793 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -100,10 +100,7 @@ case class InsertIntoHiveTable( val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) + writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) iterator.foreach { row => var i = 0 From 2fd7f72b6b0b24bec12331c7bbbcf6bfc265d2ec Mon Sep 17 00:00:00 2001 From: Alex Baretta Date: Wed, 14 Jan 2015 11:51:55 -0800 Subject: [PATCH 145/215] [SPARK-5235] Make SQLConf Serializable Declare SQLConf to be serializable to fix "Task not serializable" exceptions in SparkSQL Author: Alex Baretta Closes #4031 from alexbaretta/SPARK-5235-SQLConf and squashes the following commits: c2103f5 [Alex Baretta] [SPARK-5235] Make SQLConf Serializable --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 206d16f5b34f..3bc201a2425b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -61,7 +61,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf { +private[sql] class SQLConf extends Serializable { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ From 76389c5b99183e456ff85fd92ea68d95c4c13e82 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 14 Jan 2015 11:53:43 -0800 Subject: [PATCH 146/215] [SPARK-5234][ml]examples for ml don't have sparkContext.stop JIRA issue: https://issues.apache.org/jira/browse/SPARK-5234 simply add the call. 
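As a minimal sketch of the pattern being added (illustrative only; the object name and the workload below are not taken from the patch), each example should shut down its SparkContext once it finishes, so the driver and executors release their resources:

    import org.apache.spark.{SparkConf, SparkContext}

    object ExampleWithStop {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("ExampleWithStop"))
        // ... run the example workload ...
        println(sc.parallelize(1 to 10).sum())
        // release driver and executor resources once the example is done
        sc.stop()
      }
    }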
Author: Yuhao Yang Closes #4044 from hhbyyh/addscStop and squashes the following commits: c1f75ac [Yuhao Yang] add SparkContext.stop to 3 ml examples --- .../org/apache/spark/examples/ml/CrossValidatorExample.scala | 2 ++ .../org/apache/spark/examples/ml/SimpleParamsExample.scala | 2 ++ .../spark/examples/ml/SimpleTextClassificationPipeline.scala | 2 ++ 3 files changed, 6 insertions(+) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index ce6bc066bd70..d8c7ef38ee46 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -106,5 +106,7 @@ object CrossValidatorExample { .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) } + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 44d5b084c269..e8a2adff929c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -97,5 +97,7 @@ object SimpleParamsExample { .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) } + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 92895a05e479..b9a6ef0229de 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -85,5 +85,7 @@ object SimpleTextClassificationPipeline { .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) } + + sc.stop() } } From 13d2406781714daea2bbf3bfb7fec0dead10760c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 14 Jan 2015 17:50:33 -0800 Subject: [PATCH 147/215] [SPARK-5254][MLLIB] Update the user guide to position spark.ml better The current statement in the user guide may deliver confusing messages to users. spark.ml contains high-level APIs for building ML pipelines. But it doesn't mean that spark.mllib is being deprecated. First of all, the pipeline API is in its alpha stage and we need to see more use cases from the community to stabilizes it, which may take several releases. Secondly, the components in spark.ml are simple wrappers over spark.mllib implementations. Neither the APIs or the implementations from spark.mllib are being deprecated. We expect users use spark.ml pipeline APIs to build their ML pipelines, but we will keep supporting and adding features to spark.mllib. For example, there are many features in review at https://spark-prs.appspot.com/#mllib. So users should be comfortable with using spark.mllib features and expect more coming. The user guide needs to be updated to make the message clear. 
Author: Xiangrui Meng Closes #4052 from mengxr/SPARK-5254 and squashes the following commits: 6d5f1d3 [Xiangrui Meng] typo 0cc935b [Xiangrui Meng] update user guide to position spark.ml better --- docs/ml-guide.md | 17 ++++++++++------- docs/mllib-guide.md | 18 +++++++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 1c2e27341473..88158fd77eda 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,13 +3,16 @@ layout: global title: Spark ML Programming Guide --- -Spark ML is Spark's new machine learning package. It is currently an alpha component but is potentially a successor to [MLlib](mllib-guide.html). The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines. - -MLlib vs. Spark ML: - -* Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`. Since Spark ML is an alpha component, its API may change in future releases. -* Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. See below for more details. -* Spark ML only has Scala and Java APIs, whereas MLlib also has a Python API. +`spark.ml` is a new package introduced in Spark 1.2, which aims to provide a uniform set of +high-level APIs that help users create and tune practical machine learning pipelines. +It is currently an alpha component, and we would like to hear back from the community about +how it fits real-world use cases and how it could be improved. + +Note that we will keep supporting and adding features to `spark.mllib` along with the +development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. +Developers should contribute new algorithms to `spark.mllib` and can optionally contribute +to `spark.ml`. **Table of Contents** diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index efd7dda31071..39c64d06926b 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -35,16 +35,20 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. -# spark.ml: The New ML Package +# spark.ml: high-level APIs for ML pipelines -Spark 1.2 includes a new machine learning package called `spark.ml`, currently an alpha component but potentially a successor to `spark.mllib`. The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines. +Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of +high-level APIs that help users create and tune practical machine learning pipelines. +It is currently an alpha component, and we would like to hear back from the community about +how it fits real-world use cases and how it could be improved. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. - -Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`. +Note that we will keep supporting and adding features to `spark.mllib` along with the +development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. 
+Developers should contribute new algorithms to `spark.mllib` and can optionally contribute +to `spark.ml`. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the `spark.ml` programming guide linked above for more details. +See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. # Dependencies From cfa397c126c857bfc9843d9e598a14b7c1e0457f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Jan 2015 18:36:15 -0800 Subject: [PATCH 148/215] [SPARK-5193][SQL] Tighten up SQLContext API 1. Removed 2 implicits (logicalPlanToSparkQuery and baseRelationToSchemaRDD) 2. Moved extraStrategies into ExperimentalMethods. 3. Made private methods protected[sql] so they don't show up in javadocs. 4. Removed createParquetFile. 5. Added Java version of applySchema to SQLContext. Author: Reynold Xin Closes #4049 from rxin/sqlContext-refactor and squashes the following commits: a326a1a [Reynold Xin] Remove createParquetFile and add applySchema for Java to SQLContext. ecd6685 [Reynold Xin] Added baseRelationToSchemaRDD back. 4a38c9b [Reynold Xin] [SPARK-5193][SQL] Tighten up SQLContext API --- .../spark/sql/ExperimentalMethods.scala | 36 ++++ .../org/apache/spark/sql/SQLContext.scala | 152 ++++++++++------- .../apache/spark/sql/execution/commands.scala | 10 +- .../org/apache/spark/sql/sources/ddl.scala | 5 +- .../spark/sql/test/TestSQLContext.scala | 16 +- .../apache/spark/sql/InsertIntoSuite.scala | 160 ------------------ .../spark/sql/parquet/ParquetQuerySuite.scala | 26 --- .../sql/parquet/ParquetQuerySuite2.scala | 22 --- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../org/apache/spark/sql/hive/TestHive.scala | 2 +- 10 files changed, 152 insertions(+), 281 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala new file mode 100644 index 000000000000..f0e6a8f33218 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental + +/** + * Holder for experimental methods for the bravest. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + */ +@Experimental +class ExperimentalMethods protected[sql](sqlContext: SQLContext) { + + /** + * Allows extra strategies to be injected into the query planner at runtime. 
Note this API + * should be consider experimental and is not intended to be stable across releases. + */ + @Experimental + var extraStrategies: Seq[Strategy] = Nil + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d9f3b3a53f58..279671ced0a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql +import java.beans.Introspector import java.util.Properties import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ @@ -36,9 +37,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.json._ -import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation, DDLParser, DataSourceStrategy} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * :: AlphaComponent :: @@ -59,7 +60,7 @@ class SQLContext(@transient val sparkContext: SparkContext) self => // Note that this is a lazy val so we can override the default value in subclasses. - private[sql] lazy val conf: SQLConf = new SQLConf + protected[sql] lazy val conf: SQLConf = new SQLConf /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = conf.setConf(props) @@ -117,15 +118,6 @@ class SQLContext(@transient val sparkContext: SparkContext) case _ => } - /** - * :: DeveloperApi :: - * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan - * interface is considered internal, and thus not guaranteed to be stable. As a result, using - * them directly is not recommended. - */ - @DeveloperApi - implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = new SchemaRDD(this, plan) - /** * Creates a SchemaRDD from an RDD of case classes. * @@ -139,8 +131,11 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) } - implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { - logicalPlanToSparkQuery(LogicalRelation(baseRelation)) + /** + * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]]. + */ + def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { + new SchemaRDD(this, LogicalRelation(baseRelation)) } /** @@ -181,6 +176,43 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, logicalPlan) } + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. 
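For illustration, a minimal Scala sketch of calling the new bean-based applySchema, assuming an existing SparkContext `sc` and SQLContext `sqlContext`; the `Person` bean below is hypothetical and exists only to give the Introspector-based schema inference JavaBean-style getters to discover:

    import scala.beans.BeanProperty

    // Hypothetical bean: @BeanProperty generates getName/getAge, which the
    // bean introspection maps to (StringType, IntegerType) columns.
    class Person(@BeanProperty var name: String, @BeanProperty var age: Int) {
      def this() = this(null, 0)
    }

    val people = sc.parallelize(Seq(new Person("alice", 30), new Person("bob", 25)))
    val schemaRDD = sqlContext.applySchema(people, classOf[Person])
    schemaRDD.registerTempTable("people")
    sqlContext.sql("SELECT name, age FROM people").collect()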
+ */ + def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = { + val attributeSeq = getSchema(beanClass) + val className = beanClass.getName + val rowRdd = rdd.mapPartitions { iter => + // BeanInfo is not serializable so we must rediscover it remotely for each partition. + val localBeanInfo = Introspector.getBeanInfo( + Class.forName(className, true, Utils.getContextOrSparkClassLoader)) + val extractors = + localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) + + iter.map { row => + new GenericRow( + extractors.zip(attributeSeq).map { case (e, attr) => + DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + }.toArray[Any] + ) : Row + } + } + new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this)) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + */ + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = { + applySchema(rdd.rdd, beanClass) + } + /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. * @@ -259,41 +291,6 @@ class SQLContext(@transient val sparkContext: SparkContext) applySchema(rowRDD, appliedSchema) } - /** - * :: Experimental :: - * Creates an empty parquet file with the schema of class `A`, which can be registered as a table. - * This registered table can be used as the target of future `insertInto` operations. - * - * {{{ - * val sqlContext = new SQLContext(...) - * import sqlContext._ - * - * case class Person(name: String, age: Int) - * createParquetFile[Person]("path/to/file.parquet").registerTempTable("people") - * sql("INSERT INTO people SELECT 'michael', 29") - * }}} - * - * @tparam A A case class type that describes the desired schema of the parquet file to be - * created. - * @param path The path where the directory containing parquet metadata should be created. - * Data inserted into this table will also be stored at this location. - * @param allowExisting When false, an exception will be thrown if this directory already exists. - * @param conf A Hadoop configuration object that can be used to specify options to the parquet - * output format. - * - * @group userf - */ - @Experimental - def createParquetFile[A <: Product : TypeTag]( - path: String, - allowExisting: Boolean = true, - conf: Configuration = new Configuration()): SchemaRDD = { - new SchemaRDD( - this, - ParquetRelation.createEmpty( - path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) - } - /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only * during the lifetime of this instance of SQLContext. @@ -336,12 +333,10 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) /** - * :: DeveloperApi :: - * Allows extra strategies to be injected into the query planner at runtime. Note this API - * should be consider experimental and is not intended to be stable across releases. + * A collection of methods that are considered experimental, but can be used to hook into + * the query planner for advanced functionalities. 
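As a hedged sketch of this hook (which replaces the old `extraStrategies` var on SQLContext), user code can now inject a planning strategy through `experimental`; `MyStrategy` below is a hypothetical placeholder that matches nothing, so planning always falls through to the built-in strategies, and it assumes the planner's `Strategy` alias is visible where the object is defined:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical no-op strategy: a function from a LogicalPlan to
    // candidate physical SparkPlans that never offers a plan.
    object MyStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
    }

    // Previously: sqlContext.extraStrategies = Seq(MyStrategy)
    sqlContext.experimental.extraStrategies = Seq(MyStrategy)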
*/ - @DeveloperApi - var extraStrategies: Seq[Strategy] = Nil + val experimental: ExperimentalMethods = new ExperimentalMethods(this) protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -353,7 +348,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def numPartitions = self.conf.numShufflePartitions def strategies: Seq[Strategy] = - extraStrategies ++ ( + experimental.extraStrategies ++ ( DataSourceStrategy :: DDLStrategy :: TakeOrdered :: @@ -479,14 +474,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * have the same format as the one generated by `toString` in scala. * It is only used by PySpark. */ - private[sql] def parseDataType(dataTypeString: String): DataType = { + protected[sql] def parseDataType(dataTypeString: String): DataType = { DataType.fromJson(dataTypeString) } /** * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. */ - private[sql] def applySchemaToPythonRDD( + protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schemaString: String): SchemaRDD = { val schema = parseDataType(schemaString).asInstanceOf[StructType] @@ -496,7 +491,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Apply a schema defined by the schema to an RDD. It is only used by PySpark. */ - private[sql] def applySchemaToPythonRDD( + protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): SchemaRDD = { @@ -527,4 +522,43 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } + + /** + * Returns a Catalyst Schema for the given java bean class. + */ + protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { + // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. + val beanInfo = Introspector.getBeanInfo(beanClass) + + // Note: The ordering of elements may differ from when the schema is inferred in Scala. + // This is because beanInfo.getPropertyDescriptors gives no guarantees about + // element ordering. 
+ val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + fields.map { property => + val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + } + AttributeReference(property.getName, dataType, nullable)() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index af6b07bd6c2f..52a31f01a435 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. 
`RunnableCommand`s are @@ -137,14 +137,12 @@ case class CacheTableCommand( isLazy: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - import sqlContext._ - - plan.foreach(_.registerTempTable(tableName)) - cacheTable(tableName) + plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName)) + sqlContext.cacheTable(tableName) if (!isLazy) { // Performs eager caching - table(tableName).count() + sqlContext.table(tableName).count() } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 4cc9641c4d9e..381298caba6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -22,7 +22,7 @@ import scala.util.parsing.combinator.syntactical.StandardTokenParsers import scala.util.parsing.combinator.PackratParsers import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SchemaRDD, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SqlLexical import org.apache.spark.sql.execution.RunnableCommand @@ -234,8 +234,7 @@ private [sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - - sqlContext.baseRelationToSchemaRDD(resolved.relation).registerTempTable(tableName) + new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 8c80be106f3c..f9c082216085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.test +import scala.language.implicitConversions + import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ object TestSQLContext @@ -29,7 +32,16 @@ object TestSQLContext new SparkConf().set("spark.sql.testkey", "true"))) { /** Fewer partitions to speed up testing. */ - private[sql] override lazy val conf: SQLConf = new SQLConf { + protected[sql] override lazy val conf: SQLConf = new SQLConf { override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt } + + /** + * Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to + * construct SchemaRDD directly out of local data without relying on implicits. + */ + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = { + new SchemaRDD(this, plan) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala deleted file mode 100644 index c87d762751e6..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import _root_.java.io.File - -/* Implicits */ -import org.apache.spark.sql.test.TestSQLContext._ - -class InsertIntoSuite extends QueryTest { - TestData // Initialize TestData - import TestData._ - - test("insertInto() created parquet file") { - val testFilePath = File.createTempFile("sparkSql", "pqt") - testFilePath.delete() - testFilePath.deleteOnExit() - val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerTempTable("createAndInsertTest") - - // Add some data. - testData.insertInto("createAndInsertTest") - - // Make sure its there for a new instance of parquet file. - checkAnswer( - parquetFile(testFilePath.getCanonicalPath), - testData.collect().toSeq - ) - - // Make sure the registered table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq - ) - - // Add more data. - testData.insertInto("createAndInsertTest") - - // Make sure all data is there for a new instance of parquet file. - checkAnswer( - parquetFile(testFilePath.getCanonicalPath), - testData.collect().toSeq ++ testData.collect().toSeq - ) - - // Make sure the registered table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq ++ testData.collect().toSeq - ) - - // Now overwrite. - testData.insertInto("createAndInsertTest", overwrite = true) - - // Make sure its there for a new instance of parquet file. - checkAnswer( - parquetFile(testFilePath.getCanonicalPath), - testData.collect().toSeq - ) - - // Make sure the registered table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertTest"), - testData.collect().toSeq - ) - - testFilePath.delete() - } - - test("INSERT INTO parquet table") { - val testFilePath = File.createTempFile("sparkSql", "pqt") - testFilePath.delete() - testFilePath.deleteOnExit() - val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerTempTable("createAndInsertSQLTest") - - sql("INSERT INTO createAndInsertSQLTest SELECT * FROM testData") - - // Make sure its there for a new instance of parquet file. - checkAnswer( - parquetFile(testFilePath.getCanonicalPath), - testData.collect().toSeq - ) - - // Make sure the registered table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertSQLTest"), - testData.collect().toSeq - ) - - // Append more data. - sql("INSERT INTO createAndInsertSQLTest SELECT * FROM testData") - - // Make sure all data is there for a new instance of parquet file. - checkAnswer( - parquetFile(testFilePath.getCanonicalPath), - testData.collect().toSeq ++ testData.collect().toSeq - ) - - // Make sure the registered table has also been updated. 
- checkAnswer( - sql("SELECT * FROM createAndInsertSQLTest"), - testData.collect().toSeq ++ testData.collect().toSeq - ) - - sql("INSERT OVERWRITE INTO createAndInsertSQLTest SELECT * FROM testData") - - // Make sure its there for a new instance of parquet file. - checkAnswer( - parquetFile(testFilePath.getCanonicalPath), - testData.collect().toSeq - ) - - // Make sure the registered table has also been updated. - checkAnswer( - sql("SELECT * FROM createAndInsertSQLTest"), - testData.collect().toSeq - ) - - testFilePath.delete() - } - - test("Double create fails when allowExisting = false") { - val testFilePath = File.createTempFile("sparkSql", "pqt") - testFilePath.delete() - testFilePath.deleteOnExit() - val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - - intercept[RuntimeException] { - createParquetFile[TestData](testFilePath.getCanonicalPath, allowExisting = false) - } - - testFilePath.delete() - } - - test("Double create does not fail when allowExisting = true") { - val testFilePath = File.createTempFile("sparkSql", "pqt") - testFilePath.delete() - testFilePath.deleteOnExit() - val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - - createParquetFile[TestData](testFilePath.getCanonicalPath, allowExisting = true) - - testFilePath.delete() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index fe781ec05fb6..3a073a6b7057 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -402,23 +402,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA Utils.deleteRecursively(file) } - test("Insert (overwrite) via Scala API") { - val dirname = Utils.createTempDir() - val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - source_rdd.registerTempTable("source") - val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) - dest_rdd.registerTempTable("dest") - sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() - val rdd_copy1 = sql("SELECT * FROM dest").collect() - assert(rdd_copy1.size === 100) - - sql("INSERT INTO dest SELECT * FROM source") - val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) - assert(rdd_copy2.size === 200) - Utils.deleteRecursively(dirname) - } - test("Insert (appending) to same table via Scala API") { sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() @@ -902,15 +885,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA Utils.deleteRecursively(tmpdir) } - test("Querying on empty parquet throws exception (SPARK-3536)") { - val tmpdir = Utils.createTempDir() - Utils.deleteRecursively(tmpdir) - createParquetFile[TestRDDEntry](tmpdir.toString()).registerTempTable("tmpemptytable") - val result1 = sql("SELECT * FROM tmpemptytable").collect() - assert(result1.size === 0) - Utils.deleteRecursively(tmpdir) - } - test("read/write fixed-length decimals") { for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { val tempDir = getTempFilePath("parquetTest").getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala index 
daa7ca65cd99..4c081fb4510b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite2.scala @@ -34,19 +34,6 @@ class ParquetQuerySuite2 extends QueryTest with ParquetTest { } } - test("insertion") { - withTempDir { dir => - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - createParquetFile[(Int, String)](dir.toString).registerTempTable("dest") - withTempTable("dest") { - sql("INSERT OVERWRITE INTO dest SELECT * FROM t") - checkAnswer(table("dest"), data) - } - } - } - } - test("appending") { val data = (0 until 10).map(i => (i, i.toString)) withParquetTable(data, "t") { @@ -98,13 +85,4 @@ class ParquetQuerySuite2 extends QueryTest with ParquetTest { checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } - - test("SPARK-3536 regression: query empty Parquet file shouldn't throw") { - withTempDir { dir => - createParquetFile[(Int, String)](dir.toString).registerTempTable("t") - withTempTable("t") { - checkAnswer(sql("SELECT * FROM t"), Seq.empty[Row]) - } - } - } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index bf56e60cf995..a9a20a54bebe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -71,7 +71,7 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => - private[sql] override lazy val conf: SQLConf = new SQLConf { + protected[sql] override lazy val conf: SQLConf = new SQLConf { override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") } @@ -348,7 +348,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self - override def strategies: Seq[Strategy] = extraStrategies ++ Seq( + override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( DataSourceStrategy, HiveCommandStrategy(self), HiveDDLStrategy, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 52e1f0d94fbd..47431cef03e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -102,7 +102,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { new this.QueryExecution { val logical = plan } /** Fewer partitions to speed up testing. */ - private[sql] override lazy val conf: SQLConf = new SQLConf { + protected[sql] override lazy val conf: SQLConf = new SQLConf { override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") } From 6abc45e340d3be5f07236adc104db5f8dda0d514 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 14 Jan 2015 18:54:17 -0800 Subject: [PATCH 149/215] [SPARK-5254][MLLIB] remove developers section from spark.ml guide Forgot to remove this section in #4052. 
Author: Xiangrui Meng Closes #4053 from mengxr/SPARK-5254-update and squashes the following commits: f295bde [Xiangrui Meng] remove developers section from spark.ml guide --- docs/ml-guide.md | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 88158fd77eda..be178d7689fd 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -689,17 +689,3 @@ Spark ML currently depends on MLlib and has the same dependencies. Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info. Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. - -# Developers - -**Development plan** - -If all goes well, `spark.ml` will become the primary ML package at the time of the Spark 1.3 release. Initially, simple wrappers will be used to port algorithms to `spark.ml`, but eventually, code will be moved to `spark.ml` and `spark.mllib` will be deprecated. - -**Advice to developers** - -During the next development cycle, new algorithms should be contributed to `spark.mllib`, but we welcome patches sent to either package. If an algorithm is best expressed using the new API (e.g., feature transformers), we may ask for developers to use the new `spark.ml` API. -Wrappers for old and new algorithms can be contributed to `spark.ml`. - -Users will be able to use algorithms from either of the two packages. The main difficulty will be the differences in APIs between the two packages. - From 4b325c77a270ec32d6858d204313d4f161774fae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Jan 2015 20:31:02 -0800 Subject: [PATCH 150/215] [SPARK-5193][SQL] Tighten up HiveContext API 1. Removed the deprecated LocalHiveContext 2. Made private[sql] fields protected[sql] so they don't show up in javadoc. 3. Added javadoc to refreshTable. 4. Added Experimental tag to analyze command. Author: Reynold Xin Closes #4054 from rxin/hivecontext-api and squashes the following commits: 25cc00a [Reynold Xin] Add implicit conversion back. cbca886 [Reynold Xin] [SPARK-5193][SQL] Tighten up HiveContext API --- .../apache/spark/sql/hive/HiveContext.scala | 48 +++++-------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index a9a20a54bebe..4246b8b0918f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.io.{BufferedReader, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} import scala.collection.JavaConversions._ @@ -33,6 +33,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} @@ -42,28 +43,6 @@ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTable import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ -/** - * DEPRECATED: Use HiveContext instead. 
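For code that still constructs the removed class, a minimal hedged sketch of the replacement, assuming an existing SparkContext `sc`; a plain HiveContext still creates a local Derby metastore when none is configured (under ./metastore_db by default):

    import org.apache.spark.sql.hive.HiveContext

    // HiveContext subsumes the LocalHiveContext use case.
    val hiveContext = new HiveContext(sc)
    hiveContext.sql("SHOW TABLES").collect()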
- */ -@deprecated(""" - Use HiveContext instead. It will still create a local metastore if one is not specified. - However, note that the default directory is ./metastore_db, not ./metastore - """, "1.1") -class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { - - lazy val metastorePath = new File("metastore").getCanonicalPath - lazy val warehousePath: String = new File("warehouse").getCanonicalPath - - /** Sets up the system initially or after a RESET command */ - protected def configure() { - setConf("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$metastorePath;create=true") - setConf("hive.metastore.warehouse.dir", warehousePath) - } - - configure() // Must be called before initializing the catalog below. -} - /** * An instance of the Spark SQL execution engine that integrates with data stored in Hive. * Configuration for Hive is read from hive-site.xml on the classpath. @@ -80,7 +59,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive * SerDe. */ - private[spark] def convertMetastoreParquet: Boolean = + protected[sql] def convertMetastoreParquet: Boolean = getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = @@ -97,14 +76,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } - @deprecated("hiveql() is deprecated as the sql function now parses using HiveQL by default. " + - s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") - def hiveql(hqlQuery: String): SchemaRDD = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) - - @deprecated("hql() is deprecated as the sql function now parses using HiveQL by default. " + - s"The SQL dialect for parsing can be set using ${SQLConf.DIALECT}", "1.1") - def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) - /** * Creates a table using the schema of the given class. * @@ -116,6 +87,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting) } + /** + * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, + * Spark SQL or the external data source library it uses might cache certain metadata about a + * table, such as the location of blocks. When those change outside of Spark SQL, users should + * call this function to invalidate the cache. + */ def refreshTable(tableName: String): Unit = { // TODO: Database support... catalog.refreshTable("default", tableName) @@ -133,6 +110,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Right now, it only supports Hive tables and it only updates the size of a Hive table * in the Hive metastore. */ + @Experimental def analyze(tableName: String) { val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName))) @@ -289,7 +267,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { results } - /** * Execute the command using Hive and return the results as a sequence. Each element * in the sequence is one row. 
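A short hedged usage sketch of the two methods documented above (the table name `logs` is hypothetical):

    // Invalidate cached metadata after a table's files change outside Spark SQL.
    hiveContext.refreshTable("logs")

    // Experimental: update the table's size statistics in the Hive metastore,
    // which the planner can use, for example, to decide when to broadcast a join side.
    hiveContext.analyze("logs")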
@@ -345,7 +322,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } @transient - val hivePlanner = new SparkPlanner with HiveStrategies { + private val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( @@ -410,7 +387,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } -object HiveContext { + +private object HiveContext { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) From 3c8650c12ad7a97852e7bd76153210493fd83e92 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 15 Jan 2015 11:40:41 -0800 Subject: [PATCH 151/215] [SPARK-5224] [PySpark] improve performance of parallelize list/ndarray After the default batchSize changed to 0 (batched based on the size of object), but parallelize() still use BatchedSerializer with batchSize=1, this PR will use batchSize=1024 for parallelize by default. Also, BatchedSerializer did not work well with list and numpy.ndarray, this improve BatchedSerializer by using __len__ and __getslice__. Here is the benchmark for parallelize 1 millions int with list or ndarray: | before | after | improvements ------- | ------------ | ------------- | ------- list | 11.7 s | 0.8 s | 14x numpy.ndarray | 32 s | 0.7 s | 40x Author: Davies Liu Closes #4024 from davies/opt_numpy and squashes the following commits: 7618c7c [Davies Liu] improve performance of parallelize list/ndarray --- python/pyspark/context.py | 2 +- python/pyspark/serializers.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 593d74bca5ff..64f6a3ca6bf4 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -319,7 +319,7 @@ def f(split, iterator): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index bd08c9a6d20d..b8bda835174b 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -181,6 +181,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): def _batched(self, iterator): if self.batchSize == self.UNLIMITED_BATCH_SIZE: yield list(iterator) + elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"): + n = len(iterator) + for i in xrange(0, n, self.batchSize): + yield iterator[i: i + self.batchSize] else: items = [] count = 0 From 1881431dd50e93a6948e4966d33742727f27e917 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 15 Jan 2015 16:15:12 -0800 Subject: [PATCH 152/215] [SPARK-5274][SQL] Reconcile Java and Scala UDFRegistration. As part of SPARK-5193: 1. Removed UDFRegistration as a mixin in SQLContext and made it a field ("udf"). 2. For Java UDFs, renamed dataType to returnType. 3. For Scala UDFs, added type tags. 4. Added all Java UDF registration methods to Scala's UDFRegistration. 5. Documentation Author: Reynold Xin Closes #4056 from rxin/udf-registration and squashes the following commits: ae9c556 [Reynold Xin] Updated example. 
675a3c9 [Reynold Xin] Style fix 47c24ff [Reynold Xin] Python fix. 5f00c45 [Reynold Xin] Restore data type position in java udf and added typetags. 032f006 [Reynold Xin] [SPARK-5193][SQL] Reconcile Java and Scala UDFRegistration. --- python/pyspark/sql.py | 16 +- .../org/apache/spark/sql/SQLContext.scala | 29 +- .../apache/spark/sql/UdfRegistration.scala | 692 ++++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 9 +- .../spark/sql/UserDefinedTypeSuite.scala | 2 +- .../sql/hive/execution/HiveUdfSuite.scala | 2 +- 7 files changed, 674 insertions(+), 78 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 014ac1791c84..dcd3b60a6062 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1281,14 +1281,14 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, self._sc._gateway._gateway_client) - self._ssql_ctx.registerPython(name, - bytearray(pickled_command), - env, - includes, - self._sc.pythonExec, - broadcast_vars, - self._sc._javaAccumulator, - returnType.json()) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_command), + env, + includes, + self._sc.pythonExec, + broadcast_vars, + self._sc._javaAccumulator, + returnType.json()) def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 279671ced0a1..8ad1753dab75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -54,7 +54,6 @@ class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with CacheManager with ExpressionConversions - with UDFRegistration with Serializable { self => @@ -338,6 +337,34 @@ class SQLContext(@transient val sparkContext: SparkContext) */ val experimental: ExperimentalMethods = new ExperimentalMethods(this) + /** + * A collection of methods for registering user-defined functions (UDF). 
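The scaladoc examples that follow show registration only; as a complementary hedged sketch, a registered function can then be called by name from SQL (the function and table names here are hypothetical):

    sqlContext.udf.register("strLen", (s: String) => s.length)
    sqlContext.sql("SELECT strLen(name) FROM people").collect()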
+ * + * The following example registers a Scala closure as UDF: + * {{{ + * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * }}} + * + * The following example registers a UDF in Java: + * {{{ + * sqlContext.udf().register("myUDF", + * new UDF2() { + * @Override + * public String call(Integer arg1, String arg2) { + * return arg2 + arg1; + * } + * }, DataTypes.StringType); + * }}} + * + * Or, to use Java 8 lambda syntax: + * {{{ + * sqlContext.udf().register("myUDF", + * (Integer arg1, String arg2) -> arg2 + arg1), + * DataTypes.StringType); + * }}} + */ + val udf: UDFRegistration = new UDFRegistration(this) + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 5fb472686c9e..2e9d037f93c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -19,22 +19,26 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.Accumulator import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} import org.apache.spark.sql.execution.PythonUDF +import org.apache.spark.sql.types.DataType -import scala.reflect.runtime.universe.{TypeTag, typeTag} /** - * Functions for registering scala lambda functions as UDFs in a SQLContext. + * Functions for registering user-defined functions. */ -private[sql] trait UDFRegistration { - self: SQLContext => +class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging { + + private val functionRegistry = sqlContext.functionRegistry - private[spark] def registerPython( + protected[sql] def registerPython( name: String, command: Array[Byte], envVars: JMap[String, String], @@ -55,7 +59,7 @@ private[sql] trait UDFRegistration { """.stripMargin) - val dataType = parseDataType(stringDataType) + val dataType = sqlContext.parseDataType(stringDataType) def builder(e: Seq[Expression]) = PythonUDF( @@ -72,133 +76,699 @@ private[sql] trait UDFRegistration { functionRegistry.registerFunction(name, builder) } - /** registerFunction 0-22 were generated by this script + // scalastyle:off + + /* registerFunction 0-22 were generated by this script (0 to 22).map { x => - val types = (1 to x).foldRight("T")((_, s) => {s"_, $s"}) - s""" - def registerFunction[T: TypeTag](name: String, func: Function$x[$types]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val argDocs = (1 to x).map(i => s" * @tparam A$i type of the UDF argument at position $i.").foldLeft("")(_ + "\n" + _) + println(s""" + /** + * Register a Scala closure of ${x} arguments as user-defined function (UDF). 
+ * @tparam RT return type of UDF.$argDocs + */ + def register[$typeTags](name: String, func: Function$x[$types]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) - } - """ + }""") } - */ - // scalastyle:off - def registerFunction[T: TypeTag](name: String, func: Function0[T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + (1 to 22).foreach { i => + val extTypeArgs = (1 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + println(s""" + |/** + | * Register a user-defined function with ${i} arguments. + | */ + |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { + | functionRegistry.registerFunction( + | name, + | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), returnType, e)) + |}""".stripMargin) + } + */ + + /** + * Register a Scala closure of 0 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + */ + def register[RT: TypeTag](name: String, func: Function0[RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 1 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + */ + def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 2 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 3 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. 
+ */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 4 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 5 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 6 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 7 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. 
+ * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 8 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 9 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 10 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. 
+ * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 11 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 12 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. 
+ */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 13 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 14 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. 
+ */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 15 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 16 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + * @tparam A16 type of the UDF argument at position 16. 
+ */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 17 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + * @tparam A16 type of the UDF argument at position 16. + * @tparam A17 type of the UDF argument at position 17. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 18 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. 
+ * @tparam A16 type of the UDF argument at position 16. + * @tparam A17 type of the UDF argument at position 17. + * @tparam A18 type of the UDF argument at position 18. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 19 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + * @tparam A16 type of the UDF argument at position 16. + * @tparam A17 type of the UDF argument at position 17. + * @tparam A18 type of the UDF argument at position 18. + * @tparam A19 type of the UDF argument at position 19. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 20 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. 
+ * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + * @tparam A16 type of the UDF argument at position 16. + * @tparam A17 type of the UDF argument at position 17. + * @tparam A18 type of the UDF argument at position 18. + * @tparam A19 type of the UDF argument at position 19. + * @tparam A20 type of the UDF argument at position 20. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 21 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + * @tparam A16 type of the UDF argument at position 16. + * @tparam A17 type of the UDF argument at position 17. + * @tparam A18 type of the UDF argument at position 18. + * @tparam A19 type of the UDF argument at position 19. + * @tparam A20 type of the UDF argument at position 20. + * @tparam A21 type of the UDF argument at position 21. 
+ */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } - def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + /** + * Register a Scala closure of 22 arguments as user-defined function (UDF). + * @tparam RT return type of UDF. + * @tparam A1 type of the UDF argument at position 1. + * @tparam A2 type of the UDF argument at position 2. + * @tparam A3 type of the UDF argument at position 3. + * @tparam A4 type of the UDF argument at position 4. + * @tparam A5 type of the UDF argument at position 5. + * @tparam A6 type of the UDF argument at position 6. + * @tparam A7 type of the UDF argument at position 7. + * @tparam A8 type of the UDF argument at position 8. + * @tparam A9 type of the UDF argument at position 9. + * @tparam A10 type of the UDF argument at position 10. + * @tparam A11 type of the UDF argument at position 11. + * @tparam A12 type of the UDF argument at position 12. + * @tparam A13 type of the UDF argument at position 13. + * @tparam A14 type of the UDF argument at position 14. + * @tparam A15 type of the UDF argument at position 15. + * @tparam A16 type of the UDF argument at position 16. + * @tparam A17 type of the UDF argument at position 17. + * @tparam A18 type of the UDF argument at position 18. + * @tparam A19 type of the UDF argument at position 19. + * @tparam A20 type of the UDF argument at position 20. + * @tparam A21 type of the UDF argument at position 21. + * @tparam A22 type of the UDF argument at position 22. + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[RT].dataType, e) functionRegistry.registerFunction(name, builder) } + + /** + * Register a user-defined function with 1 arguments. + */ + def register(name: String, f: UDF1[_, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) + } + + /** + * Register a user-defined function with 2 arguments. + */ + def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 3 arguments. 
+ */ + def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 4 arguments. + */ + def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 5 arguments. + */ + def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 6 arguments. + */ + def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 7 arguments. + */ + def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 8 arguments. + */ + def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 9 arguments. + */ + def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 10 arguments. + */ + def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 11 arguments. + */ + def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 12 arguments. 
+ */ + def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 13 arguments. + */ + def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 14 arguments. + */ + def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 15 arguments. + */ + def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 16 arguments. + */ + def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 17 arguments. + */ + def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 18 arguments. 
+ */ + def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 19 arguments. + */ + def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 20 arguments. + */ + def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 21 arguments. + */ + def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + + /** + * Register a user-defined function with 22 arguments. 
+ */ + def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + } + // scalastyle:on } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cbdb3e64bb66..6c95bad6974d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -766,7 +766,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3371 Renaming a function expression with group by gives error") { - registerFunction("len", (s: String) => s.length) + udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 720953ae3765..0c9812003124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -27,23 +27,22 @@ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { test("Simple UDF") { - registerFunction("strLenScala", (_: String).length) + udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) } test("ZeroArgument UDF") { - registerFunction("random0", () => { Math.random()}) + udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - registerFunction("strLenScala", (_: String).length + (_:Int)) + udf.register("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) } - test("struct UDF") { - registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result= sql("SELECT returnStruct('test', 'test2') as ret") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a0d54d17f5f1..fbc8704f7837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -81,7 +81,7 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 5fc8d8dbe3a9..5dafcd6c0a76 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -50,7 +50,7 @@ class HiveUdfSuite extends QueryTest { import TestHive._ test("spark sql udf test that returns a struct") { - registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) assert(sql( """ |SELECT getStruct(1).f1, From 65858ba555c4c0aea988b8f4b1c8476c42377eb9 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 15 Jan 2015 17:07:44 -0800 Subject: [PATCH 153/215] [Minor] Fix tiny typo in BlockManager In BlockManager, there is a word `BlockTranserService`, but I think it's a typo for `BlockTransferService`. Author: Kousuke Saruta Closes #4046 from sarutak/fix-tiny-typo and squashes the following commits: a3e2a2f [Kousuke Saruta] Fixed tiny typo in BlockManager --- .../main/scala/org/apache/spark/storage/BlockManager.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d7b184f8a10e..1427305d91cf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -34,10 +34,9 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo -import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -120,7 +119,7 @@ private[spark] class BlockManager( private[spark] var shuffleServerId: BlockManagerId = _ // Client to read other executors' shuffle files. This is either an external service, or just the - // standard BlockTranserService to directly connect to other Executors. + // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) From 96c2c714f4f9abe20d4c42d99ffaafcb269714a1 Mon Sep 17 00:00:00 2001 From: Kostas Sakellis Date: Thu, 15 Jan 2015 17:53:42 -0800 Subject: [PATCH 154/215] [SPARK-4857] [CORE] Adds Executor membership events to SparkListener Adds onExecutorAdded and onExecutorRemoved events to the SparkListener. This allows a client to be notified when an executor has been added or removed, along with additional information such as how many vcores it is consuming. In addition, this commit adds a JavaSparkListener adapter class to the Java API that provides default (no-op) implementations of the SparkListener callbacks. This is to get around the fact that default implementations for traits don't work in Java. Having Java clients extend JavaSparkListener going forward will prevent breakage in Java when new events are added to SparkListener.
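As an illustration only (not part of this patch), a Java client might consume the new events with a sketch along the following lines; the ExecutorTracker class name and its log messages are invented for the example, and only the executor callbacks are overridden because JavaSparkListener supplies no-op defaults for everything else:

    import org.apache.spark.JavaSparkListener;
    import org.apache.spark.scheduler.SparkListenerExecutorAdded;
    import org.apache.spark.scheduler.SparkListenerExecutorRemoved;

    // Hypothetical listener that keeps a running count of live executors.
    public class ExecutorTracker extends JavaSparkListener {
      private int liveExecutors = 0;

      @Override
      public void onExecutorAdded(SparkListenerExecutorAdded added) {
        liveExecutors++;
        // ExecutorInfo carries the host and total core count from this patch.
        System.out.println("Executor " + added.executorId() + " added on "
          + added.executorInfo().executorHost() + " with "
          + added.executorInfo().totalCores() + " cores; live=" + liveExecutors);
      }

      @Override
      public void onExecutorRemoved(SparkListenerExecutorRemoved removed) {
        liveExecutors--;
        System.out.println("Executor " + removed.executorId()
          + " removed; live=" + liveExecutors);
      }
    }

Such a tracker would be registered through SparkContext#addSparkListener (reachable from Java via JavaSparkContext#sc()), after which the driver posts the executor added/removed events to it.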
Author: Kostas Sakellis Closes #3711 from ksakellis/kostas-spark-4857 and squashes the following commits: 946d2c5 [Kostas Sakellis] Added executorAdded/Removed events to MesosSchedulerBackend b1d054a [Kostas Sakellis] Remove executorInfo from ExecutorRemoved event 1727b38 [Kostas Sakellis] Renamed ExecutorDetails back to ExecutorInfo and other CR feedback 14fe78d [Kostas Sakellis] Added executor added/removed events to json protocol 93d087b [Kostas Sakellis] [SPARK-4857] [CORE] Adds Executor membership events to SparkListener --- .../org/apache/spark/JavaSparkListener.java | 97 +++++++++++++++++++ .../spark/deploy/master/ApplicationInfo.scala | 14 +-- ...{ExecutorInfo.scala => ExecutorDesc.scala} | 4 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../spark/deploy/master/WorkerInfo.scala | 6 +- .../deploy/master/ui/ApplicationPage.scala | 4 +- .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 22 ++++- .../spark/scheduler/SparkListenerBus.scala | 4 + .../CoarseGrainedSchedulerBackend.scala | 6 +- .../scheduler/cluster/ExecutorData.scala | 6 +- .../scheduler/cluster/ExecutorInfo.scala | 45 +++++++++ .../cluster/mesos/MesosSchedulerBackend.scala | 32 ++++-- .../org/apache/spark/util/JsonProtocol.scala | 40 +++++++- .../scheduler/EventLoggingListenerSuite.scala | 3 +- .../SparkListenerWithClusterSuite.scala | 62 ++++++++++++ .../mesos/MesosSchedulerBackendSuite.scala | 16 ++- .../apache/spark/util/JsonProtocolSuite.scala | 41 ++++++++ 18 files changed, 375 insertions(+), 33 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/JavaSparkListener.java rename core/src/main/scala/org/apache/spark/deploy/master/{ExecutorInfo.scala => ExecutorDesc.scala} (95%) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java new file mode 100644 index 000000000000..646496f31350 --- /dev/null +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark; + +import org.apache.spark.scheduler.SparkListener; +import org.apache.spark.scheduler.SparkListenerApplicationEnd; +import org.apache.spark.scheduler.SparkListenerApplicationStart; +import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; +import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; +import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; +import org.apache.spark.scheduler.SparkListenerExecutorAdded; +import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; +import org.apache.spark.scheduler.SparkListenerExecutorRemoved; +import org.apache.spark.scheduler.SparkListenerJobEnd; +import org.apache.spark.scheduler.SparkListenerJobStart; +import org.apache.spark.scheduler.SparkListenerStageCompleted; +import org.apache.spark.scheduler.SparkListenerStageSubmitted; +import org.apache.spark.scheduler.SparkListenerTaskEnd; +import org.apache.spark.scheduler.SparkListenerTaskGettingResult; +import org.apache.spark.scheduler.SparkListenerTaskStart; +import org.apache.spark.scheduler.SparkListenerUnpersistRDD; + +/** + * Java clients should extend this class instead of implementing + * SparkListener directly. This is to prevent java clients + * from breaking when new events are added to the SparkListener + * trait. + * + * This is a concrete class instead of abstract to enforce + * new events get added to both the SparkListener and this adapter + * in lockstep. + */ +public class JavaSparkListener implements SparkListener { + + @Override + public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { } + + @Override + public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { } + + @Override + public void onTaskStart(SparkListenerTaskStart taskStart) { } + + @Override + public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { } + + @Override + public void onTaskEnd(SparkListenerTaskEnd taskEnd) { } + + @Override + public void onJobStart(SparkListenerJobStart jobStart) { } + + @Override + public void onJobEnd(SparkListenerJobEnd jobEnd) { } + + @Override + public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { } + + @Override + public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { } + + @Override + public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { } + + @Override + public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { } + + @Override + public void onApplicationStart(SparkListenerApplicationStart applicationStart) { } + + @Override + public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { } + + @Override + public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { } + + @Override + public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } + + @Override + public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ad7d81747c37..ede0a9dbefb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -38,8 +38,8 @@ private[spark] class ApplicationInfo( extends Serializable { @transient var state: ApplicationState.Value = _ - @transient var executors: mutable.HashMap[Int, ExecutorInfo] = 
_ - @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _ + @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _ + @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _ @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ @@ -55,12 +55,12 @@ private[spark] class ApplicationInfo( private def init() { state = ApplicationState.WAITING - executors = new mutable.HashMap[Int, ExecutorInfo] + executors = new mutable.HashMap[Int, ExecutorDesc] coresGranted = 0 endTime = -1L appSource = new ApplicationSource(this) nextExecutorId = 0 - removedExecutors = new ArrayBuffer[ExecutorInfo] + removedExecutors = new ArrayBuffer[ExecutorDesc] } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -75,14 +75,14 @@ private[spark] class ApplicationInfo( } } - def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = { - val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) executors(exec.id) = exec coresGranted += cores exec } - def removeExecutor(exec: ExecutorInfo) { + def removeExecutor(exec: ExecutorDesc) { if (executors.contains(exec.id)) { removedExecutors += executors(exec.id) executors -= exec.id diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala rename to core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala index d417070c5101..5d620dfcabad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} -private[spark] class ExecutorInfo( +private[spark] class ExecutorDesc( val id: Int, val application: ApplicationInfo, val worker: WorkerInfo, @@ -37,7 +37,7 @@ private[spark] class ExecutorInfo( override def equals(other: Any): Boolean = { other match { - case info: ExecutorInfo => + case info: ExecutorDesc => fullId == info.fullId && worker.id == info.worker.id && cores == info.cores && diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 4b631ec63907..d92d99310a58 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,7 +581,7 @@ private[spark] class Master( } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! 
LaunchExecutor(masterUrl, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 473ddc23ff0f..e94aae93e449 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -38,7 +38,7 @@ private[spark] class WorkerInfo( Utils.checkHost(host, "Expected hostname") assert (port > 0) - @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info + @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info @transient var state: WorkerState.Value = _ @transient var coresUsed: Int = _ @@ -70,13 +70,13 @@ private[spark] class WorkerInfo( host + ":" + port } - def addExecutor(exec: ExecutorInfo) { + def addExecutor(exec: ExecutorDesc) { executors(exec.fullId) = exec coresUsed += exec.cores memoryUsed += exec.memory } - def removeExecutor(exec: ExecutorInfo) { + def removeExecutor(exec: ExecutorDesc) { if (executors.contains(exec.fullId)) { executors -= exec.fullId coresUsed -= exec.cores diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 4588c130ef43..3aae2b95d739 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -27,7 +27,7 @@ import org.json4s.JValue import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import org.apache.spark.deploy.master.ExecutorInfo +import org.apache.spark.deploy.master.ExecutorDesc import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -109,7 +109,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app UIUtils.basicSparkPage(content, "Application: " + app.desc.name) } - private def executorRow(executor: ExecutorInfo): Seq[Node] = { + private def executorRow(executor: ExecutorDesc): Seq[Node] = {
    {executor.id} diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 27bf4f159907..30075c172bdb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -168,6 +168,10 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + override def onExecutorAdded(event: SparkListenerExecutorAdded) = + logEvent(event, flushLogger = true) + override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + logEvent(event, flushLogger = true) // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index b62b0c131269..4840d8bd2d2f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{Distribution, Utils} @@ -84,6 +85,14 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerExecutorRemoved(executorId: String) + extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -109,7 +118,8 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal - * interface which might change in different Spark releases. + * interface which might change in different Spark releases. Java clients should extend + * {@link JavaSparkListener} */ @DeveloperApi trait SparkListener { @@ -183,6 +193,16 @@ trait SparkListener { * Called when the driver receives task metrics from an executor in a heartbeat. */ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } + + /** + * Called when the driver registers a new executor. + */ + def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { } + + /** + * Called when the driver removes an executor. 
+ */ + def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e79ffd7a3587..e700c6af542f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -70,6 +70,10 @@ private[spark] trait SparkListenerBus extends Logging { foreachListener(_.onApplicationEnd(applicationEnd)) case metricsUpdate: SparkListenerExecutorMetricsUpdate => foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) + case executorAdded: SparkListenerExecutorAdded => + foreachListener(_.onExecutorAdded(executorAdded)) + case executorRemoved: SparkListenerExecutorRemoved => + foreachListener(_.onExecutorRemoved(executorRemoved)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index fe9914b50bc5..5786d367464f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -28,7 +28,7 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} -import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} @@ -66,6 +66,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Number of executors requested from the cluster manager that have not registered yet private var numPendingExecutors = 0 + private val listenerBus = scheduler.sc.listenerBus + // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] @@ -106,6 +108,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } + listenerBus.post(SparkListenerExecutorAdded(executorId, data)) makeOffers() } @@ -213,6 +216,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) + listenerBus.post(SparkListenerExecutorRemoved(executorId)) case None => logError(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index b71bd5783d6d..eb52ddfb1eab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -31,7 +31,7 @@ import akka.actor.{Address, ActorRef} private[cluster] class ExecutorData( val executorActor: ActorRef, val executorAddress: Address, - val executorHost: String , + override val executorHost: String, var freeCores: Int, - val 
totalCores: Int -) + override val totalCores: Int +) extends ExecutorInfo(executorHost, totalCores) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala new file mode 100644 index 000000000000..b4738e64c939 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about an executor to pass from the scheduler to SparkListeners. + */ +@DeveloperApi +class ExecutorInfo( + val executorHost: String, + val totalCores: Int +) { + + def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] + + override def equals(other: Any): Boolean = other match { + case that: ExecutorInfo => + (that canEqual this) && + executorHost == that.executorHost && + totalCores == that.totalCores + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(executorHost, totalCores) + state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 75d8ddf375e2..d252fe8595fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -27,9 +27,11 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, + ExecutorInfo => MesosExecutorInfo, _} import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -62,6 +64,9 @@ private[spark] class MesosSchedulerBackend( var classLoader: ClassLoader = null + // The listener bus to publish executor added/removed events. 
+ val listenerBus = sc.listenerBus + @volatile var appId: String = _ override def start() { @@ -87,7 +92,7 @@ private[spark] class MesosSchedulerBackend( } } - def createExecutorInfo(execId: String): ExecutorInfo = { + def createExecutorInfo(execId: String): MesosExecutorInfo = { val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { @@ -141,7 +146,7 @@ private[spark] class MesosSchedulerBackend( Value.Scalar.newBuilder() .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) .build() - ExecutorInfo.newBuilder() + MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) @@ -237,6 +242,7 @@ private[spark] class MesosSchedulerBackend( } val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap + val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] @@ -260,6 +266,10 @@ private[spark] class MesosSchedulerBackend( val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? mesosTasks.foreach { case (slaveId, tasks) => + slaveIdToWorkerOffer.get(slaveId).foreach(o => + listenerBus.post(SparkListenerExecutorAdded(slaveId, + new ExecutorInfo(o.host, o.cores))) + ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -315,7 +325,7 @@ private[spark] class MesosSchedulerBackend( synchronized { if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - slaveIdsWithExecutors -= taskIdToSlaveId(tid) + removeExecutor(taskIdToSlaveId(tid)) } if (isFinished(status.getState)) { taskIdToSlaveId.remove(tid) @@ -344,12 +354,20 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + /** + * Remove executor associated with slaveId in a thread safe manner. 
+ */ + private def removeExecutor(slaveId: String) = { + synchronized { + listenerBus.post(SparkListenerExecutorRemoved(slaveId)) + slaveIdsWithExecutors -= slaveId + } + } + private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - slaveIdsWithExecutors -= slaveId.getValue - } + removeExecutor(slaveId.getValue) scheduler.executorLost(slaveId.getValue, reason) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index d94e8252650d..a02501100615 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.util.{Properties, UUID} +import org.apache.spark.scheduler.cluster.ExecutorInfo + import scala.collection.JavaConverters._ import scala.collection.Map @@ -83,7 +85,10 @@ private[spark] object JsonProtocol { applicationStartToJson(applicationStart) case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - + case executorAdded: SparkListenerExecutorAdded => + executorAddedToJson(executorAdded) + case executorRemoved: SparkListenerExecutorRemoved => + executorRemovedToJson(executorRemoved) // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing case SparkListenerExecutorMetricsUpdate(_, _) => JNothing @@ -194,6 +199,16 @@ private[spark] object JsonProtocol { ("Timestamp" -> applicationEnd.time) } + def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { + ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Executor ID" -> executorAdded.executorId) ~ + ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) + } + + def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { + ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ + ("Executor ID" -> executorRemoved.executorId) + } /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | @@ -362,6 +377,10 @@ private[spark] object JsonProtocol { ("Disk Size" -> blockStatus.diskSize) } + def executorInfoToJson(executorInfo: ExecutorInfo): JValue = { + ("Host" -> executorInfo.executorHost) ~ + ("Total Cores" -> executorInfo.totalCores) + } /** ------------------------------ * * Util JSON serialization methods | @@ -416,6 +435,8 @@ private[spark] object JsonProtocol { val unpersistRDD = Utils.getFormattedClassName(SparkListenerUnpersistRDD) val applicationStart = Utils.getFormattedClassName(SparkListenerApplicationStart) val applicationEnd = Utils.getFormattedClassName(SparkListenerApplicationEnd) + val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) + val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -431,6 +452,8 @@ private[spark] object JsonProtocol { case `unpersistRDD` => unpersistRDDFromJson(json) case `applicationStart` => applicationStartFromJson(json) case `applicationEnd` => applicationEndFromJson(json) + case `executorAdded` => executorAddedFromJson(json) + case `executorRemoved` => executorRemovedFromJson(json) } } @@ -523,6 +546,16 @@ private[spark] object JsonProtocol { SparkListenerApplicationEnd((json \ 
"Timestamp").extract[Long]) } + def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = { + val executorId = (json \ "Executor ID").extract[String] + val executorInfo = executorInfoFromJson(json \ "Executor Info") + SparkListenerExecutorAdded(executorId, executorInfo) + } + + def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = { + val executorId = (json \ "Executor ID").extract[String] + SparkListenerExecutorRemoved(executorId) + } /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | @@ -745,6 +778,11 @@ private[spark] object JsonProtocol { BlockStatus(storageLevel, memorySize, diskSize, tachyonSize) } + def executorInfoFromJson(json: JValue): ExecutorInfo = { + val executorHost = (json \ "Host").extract[String] + val totalCores = (json \ "Total Cores").extract[Int] + new ExecutorInfo(executorHost, totalCores) + } /** -------------------------------- * * Util JSON deserialization methods | diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 1de7e130039a..437d8693c0b1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -160,7 +160,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { val conf = getLoggingConf(testDirPath, compressionCodec) - val sc = new SparkContext("local", "test", conf) + val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val expectedLogDir = testDir.toURI().toString() @@ -184,6 +184,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin val eventSet = mutable.Set( SparkListenerApplicationStart, SparkListenerBlockManagerAdded, + SparkListenerExecutorAdded, SparkListenerEnvironmentUpdate, SparkListenerJobStart, SparkListenerJobEnd, diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala new file mode 100644 index 000000000000..623a687c359a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.{SparkContext, LocalSparkContext} + +import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll} + +import scala.collection.mutable + +/** + * Unit tests for SparkListener that require a local cluster. + */ +class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext + with BeforeAndAfter with BeforeAndAfterAll { + + /** Length of time to wait while draining listener events. */ + val WAIT_TIMEOUT_MILLIS = 10000 + + before { + sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite") + } + + test("SparkListener sends executor added message") { + val listener = new SaveExecutorInfo + sc.addSparkListener(listener) + + val rdd1 = sc.parallelize(1 to 100, 4) + val rdd2 = rdd1.map(_.toString) + rdd2.setName("Target RDD") + rdd2.count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(listener.addedExecutorInfo.size == 2) + assert(listener.addedExecutorInfo("0").totalCores == 1) + assert(listener.addedExecutorInfo("1").totalCores == 1) + } + + private class SaveExecutorInfo extends SparkListener { + val addedExecutorInfo = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfo(executor.executorId) = executor.executorInfo + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index 48f5e40f506d..78a30a40bf19 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -18,17 +18,20 @@ package org.apache.spark.scheduler.mesos import org.scalatest.FunSuite -import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext} -import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus, + TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} import org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos._ -import org.scalatest.mock.EasyMockSugar +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _} import org.apache.mesos.Protos.Value.Scalar import org.easymock.{Capture, EasyMock} import java.nio.ByteBuffer import java.util.Collections import java.util +import org.scalatest.mock.EasyMockSugar + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -52,11 +55,16 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea val driver = EasyMock.createMock(classOf[SchedulerDriver]) val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) + listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) + EasyMock.replay(listenerBus) + val sc = EasyMock.createMock(classOf[SparkContext]) EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() 
EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() + EasyMock.expect(sc.listenerBus).andReturn(listenerBus) EasyMock.replay(sc) val minMem = MemoryUtils.calculateTotalMemory(sc).toInt diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 63c2559c5c5f..5ba94ff67d39 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.shuffle.MetadataFetchFailedException import scala.collection.Map @@ -69,6 +70,9 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) + val executorAdded = SparkListenerExecutorAdded("exec1", + new ExecutorInfo("Hostee.awesome.com", 11)) + val executorRemoved = SparkListenerExecutorRemoved("exec2") testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -85,6 +89,8 @@ class JsonProtocolSuite extends FunSuite { testEvent(unpersistRdd, unpersistRDDJsonString) testEvent(applicationStart, applicationStartJsonString) testEvent(applicationEnd, applicationEndJsonString) + testEvent(executorAdded, executorAddedJsonString) + testEvent(executorRemoved, executorRemovedJsonString) } test("Dependent Classes") { @@ -94,6 +100,7 @@ class JsonProtocolSuite extends FunSuite { testTaskMetrics(makeTaskMetrics( 33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) + testExecutorInfo(new ExecutorInfo("host", 43)) // StorageLevel testStorageLevel(StorageLevel.NONE) @@ -303,6 +310,10 @@ class JsonProtocolSuite extends FunSuite { assert(blockId === newBlockId) } + private def testExecutorInfo(info: ExecutorInfo) { + val newInfo = JsonProtocol.executorInfoFromJson(JsonProtocol.executorInfoToJson(info)) + assertEquals(info, newInfo) + } /** -------------------------------- * | Util methods for comparing events | @@ -335,6 +346,11 @@ class JsonProtocolSuite extends FunSuite { assertEquals(e1.jobResult, e2.jobResult) case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) + case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) => + assert(e1.executorId == e1.executorId) + assertEquals(e1.executorInfo, e2.executorInfo) + case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => + assert(e1.executorId == e1.executorId) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -387,6 +403,11 @@ class JsonProtocolSuite extends FunSuite { assert(info1.accumulables === info2.accumulables) } + private def assertEquals(info1: ExecutorInfo, info2: ExecutorInfo) { + assert(info1.executorHost == info2.executorHost) + assert(info1.totalCores == info2.totalCores) + } + private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { assert(metrics1.hostname === metrics2.hostname) assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) @@ -1407,4 +1428,24 @@ class JsonProtocolSuite extends FunSuite { | "Timestamp": 42 |} """ + + private val executorAddedJsonString = 
+ """ + |{ + | "Event": "SparkListenerExecutorAdded", + | "Executor ID": "exec1", + | "Executor Info": { + | "Host": "Hostee.awesome.com", + | "Total Cores": 11 + | } + |} + """ + + private val executorRemovedJsonString = + """ + |{ + | "Event": "SparkListenerExecutorRemoved", + | "Executor ID": "exec2" + |} + """ } From a79a9f923c47f2ce7da93cf0ecfe2b66fcb9fdd4 Mon Sep 17 00:00:00 2001 From: Kostas Sakellis Date: Thu, 15 Jan 2015 18:48:39 -0800 Subject: [PATCH 155/215] [SPARK-4092] [CORE] Fix InputMetrics for coalesce'd Rdds When calculating the input metrics there was an assumption that one task only reads from one block - this is not true for some operations including coalesce. This patch simply increments the task's input metrics if previous ones existed of the same read method. A limitation to this patch is that if a task reads from two different blocks of different read methods, one will override the other. Author: Kostas Sakellis Closes #3120 from ksakellis/kostas-spark-4092 and squashes the following commits: 54e6658 [Kostas Sakellis] Drops metrics if conflicting read methods exist f0e0cc5 [Kostas Sakellis] Add bytesReadCallback to InputMetrics a2a36d4 [Kostas Sakellis] CR feedback 5a0c770 [Kostas Sakellis] [SPARK-4092] [CORE] Fix InputMetrics for coalesce'd Rdds --- .../scala/org/apache/spark/CacheManager.scala | 6 +- .../org/apache/spark/executor/Executor.scala | 1 + .../apache/spark/executor/TaskMetrics.scala | 75 ++++++- .../org/apache/spark/rdd/HadoopRDD.scala | 39 ++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 40 ++-- .../apache/spark/storage/BlockManager.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 6 +- .../metrics/InputOutputMetricsSuite.scala | 195 ++++++++++++++---- .../ui/jobs/JobProgressListenerSuite.scala | 4 +- .../apache/spark/util/JsonProtocolSuite.scala | 4 +- 10 files changed, 270 insertions(+), 102 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 80da62c44edc..a0c0372b7f0e 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,7 +44,11 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics) + val inputMetrics = blockResult.inputMetrics + val existingMetrics = context.taskMetrics + .getInputMetricsForReadMethod(inputMetrics.readMethod) + existingMetrics.addBytesRead(inputMetrics.bytesRead) + new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) case None => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b75c77b5b445..6660b98eb8ce 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -379,6 +379,7 @@ private[spark] class Executor( if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => metrics.updateShuffleReadMetrics + metrics.updateInputMetrics() metrics.jvmGCTime = curGCTime - taskRunner.startGCTime if (isLocal) { // JobProgressListener will hold an reference of it during diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 51b5328cb4c8..7eb10f95e023 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,11 @@ package org.apache.spark.executor +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.DataReadMethod.DataReadMethod + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi @@ -80,7 +85,17 @@ class TaskMetrics extends Serializable { * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read * are stored here. */ - var inputMetrics: Option[InputMetrics] = None + private var _inputMetrics: Option[InputMetrics] = None + + def inputMetrics = _inputMetrics + + /** + * This should only be used when recreating TaskMetrics, not when updating input metrics in + * executors + */ + private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) { + _inputMetrics = inputMetrics + } /** * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much @@ -133,6 +148,30 @@ class TaskMetrics extends Serializable { readMetrics } + /** + * Returns the input metrics object that the task should use. Currently, if + * there exists an input metric with the same readMethod, we return that one + * so the caller can accumulate bytes read. If the readMethod is different + * than previously seen by this task, we return a new InputMetric but don't + * record it. + * + * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, + * we can store all the different inputMetrics (one per readMethod). + */ + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): + InputMetrics =synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } + } + /** * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. */ @@ -146,6 +185,10 @@ class TaskMetrics extends Serializable { } _shuffleReadMetrics = Some(merged) } + + private[spark] def updateInputMetrics() = synchronized { + inputMetrics.foreach(_.updateBytesRead()) + } } private[spark] object TaskMetrics { @@ -179,10 +222,38 @@ object DataWriteMethod extends Enumeration with Serializable { */ @DeveloperApi case class InputMetrics(readMethod: DataReadMethod.Value) { + + private val _bytesRead: AtomicLong = new AtomicLong() + /** * Total bytes read. */ - var bytesRead: Long = 0L + def bytesRead: Long = _bytesRead.get() + @volatile @transient var bytesReadCallback: Option[() => Long] = None + + /** + * Adds additional bytes read for this read method. + */ + def addBytesRead(bytes: Long) = { + _bytesRead.addAndGet(bytes) + } + + /** + * Invoke the bytesReadCallback and mutate bytesRead. + */ + def updateBytesRead() { + bytesReadCallback.foreach { c => + _bytesRead.set(c()) + } + } + + /** + * Register a function that can be called to get up-to-date information on how many bytes the task + * has read from an input source. 
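A minimal sketch of the accumulation semantics introduced above (not part of the diff; getInputMetricsForReadMethod is private[spark], so code like this would have to live under the org.apache.spark package):

    import org.apache.spark.executor.{DataReadMethod, TaskMetrics}

    val metrics = new TaskMetrics
    // Repeated Hadoop reads in one task (e.g. several blocks after a coalesce)
    // accumulate into the single recorded InputMetrics.
    val hadoopMetrics = metrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop)
    hadoopMetrics.addBytesRead(512L)
    hadoopMetrics.addBytesRead(256L)
    assert(metrics.inputMetrics.get.bytesRead == 768L)

    // A different read method in the same task gets a fresh, unrecorded InputMetrics,
    // so the first method seen wins (the limitation tracked by SPARK-5225).
    val memoryMetrics = metrics.getInputMetricsForReadMethod(DataReadMethod.Memory)
    memoryMetrics.addBytesRead(1024L)
    assert(metrics.inputMetrics.get.bytesRead == 768L)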
+ */ + def setBytesReadCallback(f: Option[() => Long]) { + bytesReadCallback = f + } } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 37e0c13029d8..3b99d3a6cafd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,18 +213,19 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( - split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) - } else { - None - } - if (bytesReadCallback.isDefined) { - context.taskMetrics.inputMetrics = Some(inputMetrics) - } + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + split.inputSplit.value match { + case split: FileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _ => None + } + ) + inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) @@ -237,8 +238,6 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - var recordsSinceMetricsUpdate = 0 - override def getNext() = { try { finished = !reader.next(key, value) @@ -246,16 +245,6 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } - - // Update bytes read metric every few records - if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES - && bytesReadCallback.isDefined) { - recordsSinceMetricsUpdate = 0 - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() - } else { - recordsSinceMetricsUpdate += 1 - } (key, value) } @@ -263,14 +252,12 @@ class HadoopRDD[K, V]( try { reader.close() if (bytesReadCallback.isDefined) { - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() + inputMetrics.updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.bytesRead = split.inputSplit.value.getLength - context.taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.addBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e55d03d391e0..890ec677c269 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -109,18 +109,19 @@ class NewHadoopRDD[K, V]( logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( - split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) - } else { - None - } - if (bytesReadCallback.isDefined) { - context.taskMetrics.inputMetrics = Some(inputMetrics) - } + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + split.serializableHadoopSplit.value match { + case split: FileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _ => None + } + ) + inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) @@ -153,34 +154,19 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - - // Update bytes read metric every few records - if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES - && bytesReadCallback.isDefined) { - recordsSinceMetricsUpdate = 0 - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() - } else { - recordsSinceMetricsUpdate += 1 - } - (reader.getCurrentKey, reader.getCurrentValue) } private def close() { try { reader.close() - - // Update metrics with final amount if (bytesReadCallback.isDefined) { - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() + inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. 
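The callback registered above is what updateBytesRead() consults when the reader is closed; a self-contained sketch of that flow (the mutable counter stands in for the per-thread FileSystem statistics normally supplied by getFSBytesReadOnThreadCallback):

    import org.apache.spark.executor.{DataReadMethod, InputMetrics}

    var fsBytesRead = 0L  // stand-in for the FileSystem thread-local byte counter

    val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
    inputMetrics.setBytesReadCallback(Some(() => fsBytesRead))

    fsBytesRead = 4096L             // the RecordReader has consumed some input
    inputMetrics.updateBytesRead()  // pulls the latest value through the callback
    assert(inputMetrics.bytesRead == 4096L)

    // When no callback is available (non-FileSplit input), the fallback shown in the
    // diff simply adds the split length: inputMetrics.addBytesRead(splitLength)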
try { - inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength - context.taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1427305d91cf..8bc5a1cd18b6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -53,7 +53,7 @@ private[spark] class BlockResult( readMethod: DataReadMethod.Value, bytes: Long) { val inputMetrics = new InputMetrics(readMethod) - inputMetrics.bytesRead = bytes + inputMetrics.addBytesRead(bytes) } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a02501100615..ee3756c226fe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -637,8 +637,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) - metrics.inputMetrics = - Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson) + metrics.setInputMetrics( + Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson)) metrics.outputMetrics = Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson) metrics.updatedBlocks = @@ -671,7 +671,7 @@ private[spark] object JsonProtocol { def inputMetricsFromJson(json: JValue): InputMetrics = { val metrics = new InputMetrics( DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.bytesRead = (json \ "Bytes Read").extract[Long] + metrics.addBytesRead((json \ "Bytes Read").extract[Long]) metrics } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index f8bcde12a371..10a39990f80c 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -17,66 +17,185 @@ package org.apache.spark.metrics -import java.io.{FileWriter, PrintWriter, File} +import java.io.{File, FileWriter, PrintWriter} -import org.apache.spark.SharedSparkContext -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener} +import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite -import org.scalatest.Matchers import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.SharedSparkContext +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.util.Utils + +class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext with Matchers { 
- test("input metrics when reading text file with single split") { - val file = new File(getClass.getSimpleName + ".txt") - val pw = new PrintWriter(new FileWriter(file)) - pw.println("some stuff") - pw.println("some other stuff") - pw.println("yet more stuff") - pw.println("too much stuff") + @transient var tmpDir: File = _ + @transient var tmpFile: File = _ + @transient var tmpFilePath: String = _ + + override def beforeAll() { + super.beforeAll() + + tmpDir = Utils.createTempDir() + val testTempDir = new File(tmpDir, "test") + testTempDir.mkdir() + + tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(tmpFile)) + for (x <- 1 to 1000000) { + pw.println("s") + } pw.close() - file.deleteOnExit() - val taskBytesRead = new ArrayBuffer[Long]() - sc.addSparkListener(new SparkListener() { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead - } - }) - sc.textFile("file://" + file.getAbsolutePath, 2).count() + // Path to tmpFile + tmpFilePath = "file://" + tmpFile.getAbsolutePath + } - // Wait for task end events to come in - sc.listenerBus.waitUntilEmpty(500) - assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum >= file.length()) + override def afterAll() { + super.afterAll() + Utils.deleteRecursively(tmpDir) } - test("input metrics when reading text file with multiple splits") { - val file = new File(getClass.getSimpleName + ".txt") - val pw = new PrintWriter(new FileWriter(file)) - for (i <- 0 until 10000) { - pw.println("some stuff") + test("input metrics for old hadoop with coalesce") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).count() + } + val bytesRead2 = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).coalesce(2).count() + } + assert(bytesRead != 0) + assert(bytesRead == bytesRead2) + assert(bytesRead2 >= tmpFile.length()) + } + + test("input metrics with cache and coalesce") { + // prime the cache manager + val rdd = sc.textFile(tmpFilePath, 4).cache() + rdd.collect() + + val bytesRead = runAndReturnBytesRead { + rdd.count() + } + val bytesRead2 = runAndReturnBytesRead { + rdd.coalesce(4).count() } - pw.close() - file.deleteOnExit() + // for count and coelesce, the same bytes should be read. + assert(bytesRead != 0) + assert(bytesRead2 == bytesRead) + } + + /** + * This checks the situation where we have interleaved reads from + * different sources. Currently, we only accumulate fron the first + * read method we find in the task. This test uses cartesian to create + * the interleaved reads. + * + * Once https://issues.apache.org/jira/browse/SPARK-5225 is fixed + * this test should break. + */ + test("input metrics with mixed read method") { + // prime the cache manager + val numPartitions = 2 + val rdd = sc.parallelize(1 to 100, numPartitions).cache() + rdd.collect() + + val rdd2 = sc.textFile(tmpFilePath, numPartitions) + + val bytesRead = runAndReturnBytesRead { + rdd.count() + } + val bytesRead2 = runAndReturnBytesRead { + rdd2.count() + } + + val cartRead = runAndReturnBytesRead { + rdd.cartesian(rdd2).count() + } + + assert(cartRead != 0) + assert(bytesRead != 0) + // We read from the first rdd of the cartesian once per partition. 
+ assert(cartRead == bytesRead * numPartitions) + } + + test("input metrics for new Hadoop API with coalesce") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).count() + } + val bytesRead2 = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).coalesce(5).count() + } + assert(bytesRead != 0) + assert(bytesRead2 == bytesRead) + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics when reading text file") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with interleaved reads") { + val numPartitions = 2 + val cartVector = 0 to 9 + val cartFile = new File(tmpDir, getClass.getSimpleName + "_cart.txt") + val cartFilePath = "file://" + cartFile.getAbsolutePath + + // write files to disk so we can read them later. + sc.parallelize(cartVector).saveAsTextFile(cartFilePath) + val aRdd = sc.textFile(cartFilePath, numPartitions) + + val tmpRdd = sc.textFile(tmpFilePath, numPartitions) + + val firstSize= runAndReturnBytesRead { + aRdd.count() + } + val secondSize = runAndReturnBytesRead { + tmpRdd.count() + } + + val cartesianBytes = runAndReturnBytesRead { + aRdd.cartesian(tmpRdd).count() + } + + // Computing the amount of bytes read for a cartesian operation is a little involved. + // Cartesian interleaves reads between two partitions eg. p1 and p2. + // Here are the steps: + // 1) First it creates an iterator for p1 + // 2) Creates an iterator for p2 + // 3) Reads the first element of p1 and then all the elements of p2 + // 4) proceeds to the next element of p1 + // 5) Creates a new iterator for p2 + // 6) rinse and repeat. + // As a result we read from the second partition n times where n is the number of keys in + // p1. Thus the math below for the test. 
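Put in numbers (an illustration of the comment above, not part of the test): with numPartitions = 2 and a 10-element cartVector, each aRdd partition holds 5 keys and the four cartesian partitions read

    aRdd bytes:   4 pairs * (firstSize / 2)            = 2 * firstSize   = numPartitions * firstSize
    tmpRdd bytes: 4 pairs * 5 keys * (secondSize / 2)  = 10 * secondSize = cartVector.length * secondSize

which is exactly the sum asserted below.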
+ assert(cartesianBytes != 0) + assert(cartesianBytes == firstSize * numPartitions + (cartVector.length * secondSize)) + } + + private def runAndReturnBytesRead(job : => Unit): Long = { val taskBytesRead = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead } }) - sc.textFile("file://" + file.getAbsolutePath, 2).count() - // Wait for task end events to come in + job + sc.listenerBus.waitUntilEmpty(500) - assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum >= file.length()) + taskBytesRead.sum } test("output metrics when writing text file") { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 12af60caf7d5..f865d8ca04d1 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -231,8 +231,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.diskBytesSpilled = base + 5 taskMetrics.memoryBytesSpilled = base + 6 val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - taskMetrics.inputMetrics = Some(inputMetrics) - inputMetrics.bytesRead = base + 7 + taskMetrics.setInputMetrics(Some(inputMetrics)) + inputMetrics.addBytesRead(base + 7) val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) taskMetrics.outputMetrics = Some(outputMetrics) outputMetrics.bytesWritten = base + 8 diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 5ba94ff67d39..71dfed128985 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -630,8 +630,8 @@ class JsonProtocolSuite extends FunSuite { if (hasHadoopInput) { val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - inputMetrics.bytesRead = d + e + f - t.inputMetrics = Some(inputMetrics) + inputMetrics.addBytesRead(d + e + f) + t.setInputMetrics(Some(inputMetrics)) } else { val sr = new ShuffleReadMetrics sr.remoteBytesRead = b + d From 2be82b1e66cd188456bbf1e5abb13af04d1629d5 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Fri, 16 Jan 2015 09:16:56 -0800 Subject: [PATCH 156/215] [SPARK-1507][YARN]specify # cores for ApplicationMaster Based on top of changes in https://github.com/apache/spark/pull/3806. https://issues.apache.org/jira/browse/SPARK-1507 `--driver-cores` and `spark.driver.cores` for all cluster modes and `spark.yarn.am.cores` for yarn client mode. 
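For illustration only (not taken from the patch), the new settings would be exercised roughly like this; the application jar and remaining arguments are placeholders:

    # YARN cluster mode: the driver, and hence the AM it runs inside, gets 4 cores
    $ bin/spark-submit --master yarn-cluster --driver-cores 4 ...

    # YARN client mode: the driver runs locally, so size the AM via spark.yarn.am.cores
    $ bin/spark-submit --master yarn-client --conf spark.yarn.am.cores=2 ...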
Author: WangTaoTheTonic Author: WangTao Closes #4018 from WangTaoTheTonic/SPARK-1507 and squashes the following commits: 01419d3 [WangTaoTheTonic] amend the args name b255795 [WangTaoTheTonic] indet thing d86557c [WangTaoTheTonic] some comments amend 43c9392 [WangTao] fix compile error b39a100 [WangTao] specify # cores for ApplicationMaster --- .../apache/spark/deploy/ClientArguments.scala | 6 ++--- .../org/apache/spark/deploy/SparkSubmit.scala | 1 + .../spark/deploy/SparkSubmitArguments.scala | 5 ++++ docs/configuration.md | 15 ++++++++---- docs/running-on-yarn.md | 17 +++++++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 1 + .../spark/deploy/yarn/ClientArguments.scala | 24 +++++++++++++++---- 7 files changed, 58 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 2e1e52906cee..e5873ce724b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer import org.apache.log4j.Level -import org.apache.spark.util.MemoryParam +import org.apache.spark.util.{IntParam, MemoryParam} /** * Command-line parser for the driver client. @@ -51,8 +51,8 @@ private[spark] class ClientArguments(args: Array[String]) { parse(args.toList) def parse(args: List[String]): Unit = args match { - case ("--cores" | "-c") :: value :: tail => - cores = value.toInt + case ("--cores" | "-c") :: IntParam(value) :: tail => + cores = value parse(tail) case ("--memory" | "-m") :: MemoryParam(value) :: tail => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 955cbd6dab96..050ba91eb2bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -200,6 +200,7 @@ object SparkSubmit { // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), + OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 47059b08a397..81ec08cb6d50 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -108,6 +108,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) .orNull + driverCores = Option(driverCores) + .orElse(sparkProperties.get("spark.driver.cores")) + .orNull executorMemory = Option(executorMemory) .orElse(sparkProperties.get("spark.executor.memory")) .orElse(env.get("SPARK_EXECUTOR_MEMORY")) @@ -406,6 +409,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --total-executor-cores NUM Total cores for all executors. 
 |
 | YARN-only:
+|  --driver-cores NUM          Number of cores used by the driver, only in cluster mode
+|                              (Default: 1).
 |  --executor-cores NUM        Number of cores per executor (Default: 1).
 |  --queue QUEUE_NAME          The YARN queue to submit to (Default: "default").
 |  --num-executors NUM         Number of executors to launch (Default: 2).
diff --git a/docs/configuration.md b/docs/configuration.md
index 673cdb371a51..efbab4085317 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -102,11 +102,10 @@ of the most common options to set are:
-  spark.executor.memory        512m
+  spark.driver.cores           1
-    Amount of memory to use per executor process, in the same format as JVM memory strings
-    (e.g. 512m, 2g).
+    Number of cores to use for the driver process, only in cluster mode.
+  spark.executor.memory        512m
+    Amount of memory to use per executor process, in the same format as JVM memory strings
+    (e.g. 512m, 2g).
   spark.driver.maxResultSize   1g
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
+  spark.driver.cores           1
+    Number of cores used by the driver in YARN cluster mode.
+    Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM.
+    In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead.
+  spark.yarn.am.cores          1
+    Number of cores to use for the YARN Application Master in client mode.
+    In cluster mode, use spark.driver.cores instead.
   spark.yarn.am.waitTime       100000

   cogroup(otherDataset, [numTasks])
-    When called on datasets of type (K, V) and (K, W), returns a dataset of (K, Iterable<V>, Iterable<W>) tuples. This operation is also called groupWith.
+    When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith.
   cartesian(otherDataset)

+  {formattedSubmissionTime}
+  {formattedDuration}